Skip to content

Commit dbecec2

Browse files
authored
fix: topic filtering and drop DDB Libraries (reconbot#57)
- topic filtering wasn't working correctly - prepping for revamp of the pub/sub system as we can do more better =) - should be trivial to add lookup by topic and connectionId ev refactor DDB to not use any libraries
1 parent b8abad3 commit dbecec2

25 files changed

+353
-392
lines changed

lib/ddb/DDB.ts

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import { DynamoDB } from 'aws-sdk'
2+
import { LoggerFunction, DDBType } from '../types'
3+
4+
export interface DDBClient<T extends DDBType> {
5+
get: (id: string) => Promise<T|null>
6+
put: (Item: T) => Promise<T>
7+
update: (id: string, obj: Partial<T>) => Promise<T>
8+
delete: (id: string) => Promise<T>
9+
query: (options: Omit<DynamoDB.DocumentClient.QueryInput, 'TableName' | 'Select'>) => AsyncGenerator<T, void, undefined>
10+
}
11+
12+
export const DDB = <T extends DDBType>({
13+
dynamodb,
14+
tableName,
15+
log,
16+
}: {
17+
dynamodb: DynamoDB
18+
tableName: string
19+
log: LoggerFunction
20+
}): DDBClient<T> => {
21+
const documentClient = new DynamoDB.DocumentClient({ service: dynamodb })
22+
23+
const get = async (id: string): Promise<null | T> => {
24+
log('get', { tableName: tableName, id })
25+
const { Item } = await documentClient.get({
26+
TableName: tableName,
27+
Key: { id },
28+
}).promise()
29+
return (Item as T) ?? null
30+
}
31+
32+
const put = async (Item: T): Promise<T> => {
33+
log('put', { tableName: tableName, Item })
34+
const { Attributes } = await documentClient.put({
35+
TableName: tableName,
36+
Item,
37+
ReturnValues: 'ALL_OLD',
38+
}).promise()
39+
return Attributes as T
40+
}
41+
42+
const update = async (id: string, obj: Partial<T>) => {
43+
const AttributeUpdates = Object.entries(obj)
44+
.map(([key, Value]) => ({ [key]: { Value, Action: 'PUT' } }))
45+
.reduce((memo, val) => ({ ...memo, ...val }))
46+
47+
const { Attributes } = await documentClient.update({
48+
TableName: tableName,
49+
Key: { id },
50+
AttributeUpdates,
51+
ReturnValues: 'ALL_NEW',
52+
}).promise()
53+
return Attributes as T
54+
}
55+
56+
const deleteFunction = async (id: string): Promise<T> => {
57+
const { Attributes } = await documentClient.delete({
58+
TableName: tableName,
59+
Key: { id },
60+
ReturnValues: 'ALL_OLD',
61+
}).promise()
62+
return Attributes as T
63+
}
64+
65+
const queryOnce = async (options: Omit<DynamoDB.DocumentClient.QueryInput, 'TableName' | 'Select'>) => {
66+
log('queryOnce', options)
67+
68+
const response = await documentClient.query({
69+
TableName: tableName,
70+
Select: 'ALL_ATTRIBUTES',
71+
...options,
72+
}).promise()
73+
74+
const { Items, LastEvaluatedKey, Count } = response
75+
return {
76+
items: (Items ?? []) as T[],
77+
lastEvaluatedKey: LastEvaluatedKey,
78+
count: Count ?? 0,
79+
}
80+
}
81+
82+
async function* query(options: Omit<DynamoDB.DocumentClient.QueryInput, 'TableName' | 'Select'>) {
83+
log('query', options)
84+
const results = await queryOnce(options)
85+
yield* results.items
86+
let lastEvaluatedKey = results.lastEvaluatedKey
87+
while (lastEvaluatedKey) {
88+
const results = await queryOnce({ ...options, ExclusiveStartKey: lastEvaluatedKey })
89+
yield* results.items
90+
lastEvaluatedKey = results.lastEvaluatedKey
91+
}
92+
}
93+
94+
return {
95+
get,
96+
put,
97+
update,
98+
query,
99+
delete: deleteFunction,
100+
}
101+
}

lib/handleStepFunctionEvent.ts

+3-8
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,11 @@ export const handleStepFunctionEvent = (serverPromise: Promise<ServerClosure>):
88
if (!server.pingpong) {
99
throw new Error('Invalid pingpong settings')
1010
}
11-
const connection = Object.assign(new server.model.Connection(), {
12-
id: input.connectionId,
13-
})
1411

1512
// Initial state - send ping message
1613
if (input.state === 'PING') {
1714
await postToConnection(server)({ ...input, message: { type: MessageType.Ping } })
18-
await server.mapper.update(Object.assign(connection, { hasPonged: false }), {
19-
onMissing: 'skip',
20-
})
15+
await server.models.connection.update(input.connectionId, { hasPonged: false })
2116
return {
2217
...input,
2318
state: 'REVIEW',
@@ -26,8 +21,8 @@ export const handleStepFunctionEvent = (serverPromise: Promise<ServerClosure>):
2621
}
2722

2823
// Follow up state - check if pong was returned
29-
const conn = await server.mapper.get(connection)
30-
if (conn.hasPonged) {
24+
const conn = await server.models.connection.get(input.connectionId)
25+
if (conn?.hasPonged) {
3126
return {
3227
...input,
3328
state: 'PING',

lib/index.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ export {
3434
PubSubEvent,
3535
SubscriptionDefinition,
3636
SubscriptionFilter,
37+
Connection,
38+
Subscription,
3739
} from './types'
38-
export { Subscription } from './model/Subscription'
39-
export { Connection } from './model/Connection'

lib/makeServerClosure.ts

+8-16
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,26 @@
1-
import { DataMapper } from '@aws/dynamodb-data-mapper'
2-
import { ServerArgs, ServerClosure } from './types'
3-
import { createModel } from './model/createModel'
4-
import { Subscription } from './model/Subscription'
5-
import { Connection } from './model/Connection'
1+
import { ServerArgs, ServerClosure, Connection, Subscription } from './types'
2+
import { DDB } from './ddb/DDB'
63
import { log as debugLogger } from './utils/logger'
74

85
export const makeServerClosure = async (opts: ServerArgs): Promise<ServerClosure> => {
96
const {
107
tableNames,
118
log = debugLogger,
12-
dynamodb,
9+
dynamodb: dynamodbPromise,
1310
apiGatewayManagementApi,
1411
pingpong,
1512
...rest
1613
} = opts
14+
const dynamodb = await dynamodbPromise
1715
return {
1816
...rest,
1917
apiGatewayManagementApi: await apiGatewayManagementApi,
2018
pingpong: await pingpong,
19+
dynamodb: dynamodb,
2120
log,
22-
model: {
23-
Subscription: createModel({
24-
model: Subscription,
25-
table: (await tableNames)?.subscriptions || 'graphql_subscriptions',
26-
}),
27-
Connection: createModel({
28-
model: Connection,
29-
table: (await tableNames)?.connections || 'graphql_connections',
30-
}),
21+
models: {
22+
subscription: DDB<Subscription>({ dynamodb, tableName: (await tableNames)?.subscriptions || 'graphql_subscriptions', log }),
23+
connection: DDB<Connection>({ dynamodb, tableName: (await tableNames)?.connections || 'graphql_connections', log }),
3124
},
32-
mapper: new DataMapper({ client: await dynamodb }),
3325
}
3426
}

lib/messages/complete.ts

+7-12
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import AggregateError from 'aggregate-error'
22
import { parse } from 'graphql'
33
import { CompleteMessage } from 'graphql-ws'
44
import { buildExecutionContext } from 'graphql/execution/execute'
5-
import { collect } from 'streaming-iterables'
65
import { SubscribePseudoIterable, MessageHandler, PubSubEvent } from '../types'
76
import { deleteConnection } from '../utils/deleteConnection'
87
import { buildContext } from '../utils/buildContext'
@@ -14,21 +13,17 @@ export const complete: MessageHandler<CompleteMessage> =
1413
async ({ server, event, message }) => {
1514
server.log('messages:complete', { connectionId: event.requestContext.connectionId })
1615
try {
17-
const topicSubscriptions = await collect(server.mapper.query(server.model.Subscription, {
18-
id: `${event.requestContext.connectionId}|${message.id}`,
19-
}))
20-
if (topicSubscriptions.length === 0) {
16+
const subscription = await server.models.subscription.get(`${event.requestContext.connectionId}|${message.id}`)
17+
if (!subscription) {
2118
return
2219
}
23-
// only call onComplete on the first one as any others are duplicates
24-
const sub = topicSubscriptions[0]
2520
const execContext = buildExecutionContext(
2621
server.schema,
27-
parse(sub.subscription.query),
22+
parse(subscription.subscription.query),
2823
undefined,
29-
await buildContext({ server, connectionInitPayload: sub.connectionInitPayload, connectionId: sub.connectionId }),
30-
sub.subscription.variables,
31-
sub.subscription.operationName,
24+
await buildContext({ server, connectionInitPayload: subscription.connectionInitPayload, connectionId: subscription.connectionId }),
25+
subscription.subscription.variables,
26+
subscription.subscription.operationName,
3227
undefined,
3328
)
3429

@@ -42,7 +37,7 @@ export const complete: MessageHandler<CompleteMessage> =
4237
server.log('messages:complete:onComplete', { onComplete: !!onComplete })
4338
await onComplete?.(root, args, context, info)
4439

45-
await Promise.all(topicSubscriptions.map(sub => server.mapper.delete(sub)))
40+
await server.models.subscription.delete(subscription.id)
4641
} catch (err) {
4742
server.log('messages:complete:onError', { err, event })
4843
await server.onError?.(err, { event, message })

lib/messages/connection_init.ts

+5-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { ConnectionInitMessage, MessageType } from 'graphql-ws'
33
import { StateFunctionInput, MessageHandler } from '../types'
44
import { postToConnection } from '../utils/postToConnection'
55
import { deleteConnection } from '../utils/deleteConnection'
6+
import { defaultTTL } from '../utils/defaultTTL'
67

78
/** Handler function for 'connection_init' message. */
89
export const connection_init: MessageHandler<ConnectionInitMessage> =
@@ -28,12 +29,14 @@ export const connection_init: MessageHandler<ConnectionInitMessage> =
2829
}
2930

3031
// Write to persistence
31-
const connection = Object.assign(new server.model.Connection(), {
32+
await server.models.connection.put({
3233
id: event.requestContext.connectionId,
34+
createdAt: Date.now(),
3335
requestContext: event.requestContext,
3436
payload,
37+
hasPonged: false,
38+
ttl: defaultTTL(),
3539
})
36-
await server.mapper.put(connection)
3740
return postToConnection(server)({
3841
...event.requestContext,
3942
message: { type: MessageType.ConnectionAck },

lib/messages/disconnect.ts

+9-15
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,25 @@
11
import AggregateError from 'aggregate-error'
22
import { parse } from 'graphql'
3-
import { equals } from '@aws/dynamodb-expressions'
43
import { buildExecutionContext } from 'graphql/execution/execute'
54
import { buildContext } from '../utils/buildContext'
65
import { getResolverAndArgs } from '../utils/getResolverAndArgs'
76
import { SubscribePseudoIterable, MessageHandler, PubSubEvent } from '../types'
87
import { isArray } from '../utils/isArray'
98
import { collect } from 'streaming-iterables'
10-
import { Connection } from '../model/Connection'
9+
import { Connection } from '../types'
1110

1211
/** Handler function for 'disconnect' message. */
1312
export const disconnect: MessageHandler<null> = async ({ server, event }) => {
1413
server.log('messages:disconnect', { connectionId: event.requestContext.connectionId })
1514
try {
1615
await server.onDisconnect?.({ event })
1716

18-
const topicSubscriptions = await collect(server.mapper.query(
19-
server.model.Subscription,
20-
{
21-
connectionId: equals(event.requestContext.connectionId),
22-
},
23-
{ indexName: 'ConnectionIndex' },
24-
))
17+
const topicSubscriptions = await collect(server.models.subscription.query({
18+
IndexName: 'ConnectionIndex',
19+
ExpressionAttributeNames: { '#a': 'connectionId' },
20+
ExpressionAttributeValues: { ':1': event.requestContext.connectionId },
21+
KeyConditionExpression: '#a = :1',
22+
}))
2523

2624
const completed = {} as Record<string, boolean>
2725
const deletions = [] as Promise<void|Connection>[]
@@ -54,7 +52,7 @@ export const disconnect: MessageHandler<null> = async ({ server, event }) => {
5452
await onComplete?.(root, args, context, info)
5553
}
5654

57-
await server.mapper.delete(sub)
55+
await server.models.subscription.delete(sub.id)
5856
})(),
5957
)
6058
}
@@ -63,11 +61,7 @@ export const disconnect: MessageHandler<null> = async ({ server, event }) => {
6361
// Delete subscriptions
6462
...deletions,
6563
// Delete connection
66-
server.mapper.delete(
67-
Object.assign(new server.model.Connection(), {
68-
id: event.requestContext.connectionId,
69-
}),
70-
),
64+
server.models.connection.delete(event.requestContext.connectionId),
7165
])
7266
} catch (err) {
7367
server.log('messages:disconnect:onError', { err, event })

lib/messages/pong.ts

+3-9
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,9 @@ export const pong: MessageHandler<PongMessage> =
77
async ({ server, event, message }) => {
88
try {
99
await server.onPong?.({ event, message })
10-
await server.mapper.update(
11-
Object.assign(new server.model.Connection(), {
12-
id: event.requestContext.connectionId,
13-
hasPonged: true,
14-
}),
15-
{
16-
onMissing: 'skip',
17-
},
18-
)
10+
await server.models.connection.update(event.requestContext.connectionId, {
11+
hasPonged: true,
12+
})
1913
} catch (err) {
2014
await server.onError?.(err, { event, message })
2115
await deleteConnection(server)(event.requestContext)

lib/messages/subscribe-test.ts

+20-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import { tables } from '@architect/sandbox'
44
import { subscribe } from './subscribe'
55
import { mockServerContext } from '../test/mockServer'
66
import { connection_init } from './connection_init'
7-
import { equals } from '@aws/dynamodb-expressions'
87
import { collect } from 'streaming-iterables'
98
import { subscribe as pubsubSubscribe } from '../pubsub/subscribe'
109
import { makeExecutableSchema } from '@graphql-tools/schema'
@@ -60,8 +59,14 @@ describe('messages/subscribe', () => {
6059
],
6160
delete: [],
6261
})
63-
const [subscriptions] = await collect(server.mapper.query(server.model.Subscription, { connectionId: equals(event.requestContext.connectionId) }, { indexName: 'ConnectionIndex' }))
64-
assert.include(subscriptions, { connectionId, subscriptionId: '1234' })
62+
63+
const [subscriptions] = await collect(server.models.subscription.query({
64+
IndexName: 'ConnectionIndex',
65+
ExpressionAttributeNames: { '#a': 'connectionId' },
66+
ExpressionAttributeValues: { ':1': event.requestContext.connectionId },
67+
KeyConditionExpression: '#a = :1',
68+
}))
69+
assert.containSubset(subscriptions, { connectionId, subscriptionId: '1234' })
6570
})
6671

6772
it('sends errors on error', async () => {
@@ -158,7 +163,12 @@ describe('messages/subscribe', () => {
158163
assert.equal(error.message, 'don\'t subscribe!')
159164
}
160165
assert.deepEqual(onSubscribe, ['We did it!'])
161-
const subscriptions = await collect(server.mapper.query(server.model.Subscription, { connectionId: equals(event.requestContext.connectionId) }, { indexName: 'ConnectionIndex' }))
166+
const subscriptions = await collect(server.models.subscription.query({
167+
IndexName: 'ConnectionIndex',
168+
ExpressionAttributeNames: { '#a': 'connectionId' },
169+
ExpressionAttributeValues: { ':1': event.requestContext.connectionId },
170+
KeyConditionExpression: '#a = :1',
171+
}))
162172
assert.isEmpty(subscriptions)
163173
})
164174

@@ -206,7 +216,12 @@ describe('messages/subscribe', () => {
206216
await connection_init({ server, event: connectionInitEvent, message: JSON.parse(connectionInitEvent.body) })
207217
await subscribe({ server, event, message: JSON.parse(event.body) })
208218
assert.deepEqual(events, ['onSubscribe', 'onAfterSubscribe'])
209-
const subscriptions = await collect(server.mapper.query(server.model.Subscription, { connectionId: equals(event.requestContext.connectionId) }, { indexName: 'ConnectionIndex' }))
219+
const subscriptions = await collect(server.models.subscription.query({
220+
IndexName: 'ConnectionIndex',
221+
ExpressionAttributeNames: { '#a': 'connectionId' },
222+
ExpressionAttributeValues: { ':1': event.requestContext.connectionId },
223+
KeyConditionExpression: '#a = :1',
224+
}))
210225
assert.isNotEmpty(subscriptions)
211226
})
212227
})

0 commit comments

Comments
 (0)