diff --git a/apps/browser-proxy/package.json b/apps/browser-proxy/package.json index d94243da..178abf25 100644 --- a/apps/browser-proxy/package.json +++ b/apps/browser-proxy/package.json @@ -8,6 +8,7 @@ }, "dependencies": { "@aws-sdk/client-s3": "^3.645.0", + "async-mutex": "^0.5.0", "debug": "^4.3.7", "expiry-map": "^2.0.0", "findhit-proxywrap": "^0.3.13", @@ -19,6 +20,7 @@ "@total-typescript/tsconfig": "^1.0.4", "@types/debug": "^4.1.12", "@types/node": "^22.5.4", + "prisma": "^5.19.1", "typescript": "^5.5.4" } } diff --git a/apps/browser-proxy/src/connection-manager.ts b/apps/browser-proxy/src/connection-manager.ts new file mode 100644 index 00000000..358514ae --- /dev/null +++ b/apps/browser-proxy/src/connection-manager.ts @@ -0,0 +1,119 @@ +import { WebSocket } from 'ws' +import { PostgresConnection } from 'pg-gateway' +import makeDebug from 'debug' +import { Mutex } from 'async-mutex' + +const debug = makeDebug('browser-proxy') + +export class ConnectionManager { + private tcpConnections = new Map() + private websocketConnections = new Map() + private activeConnectionIds = new Map() + private mutexes: Map = new Map() + + async processMessage(databaseId: string, connectionId: string, message: Uint8Array) { + const key = `${databaseId}:${connectionId}` + if (!this.mutexes.has(key)) { + this.mutexes.set(key, new Mutex()) + } + const mutex = this.mutexes.get(key)! + + await mutex.runExclusive(async () => { + // Process the message + this.sendMessageToWebSocket(databaseId, message) + }) + } + + addTcpConnection(databaseId: string, connection: PostgresConnection): string | null { + if (this.tcpConnections.has(databaseId)) { + debug('TCP connection already exists for database', databaseId) + return null + } + const connectionId = Date.now().toString(36) + Math.random().toString(36).substr(2) + this.tcpConnections.set(databaseId, connection) + this.activeConnectionIds.set(databaseId, connectionId) + return connectionId + } + + removeTcpConnection(databaseId: string) { + const connectionId = this.activeConnectionIds.get(databaseId) + if (connectionId) { + this.mutexes.delete(`${databaseId}:${connectionId}`) + } + this.tcpConnections.delete(databaseId) + this.activeConnectionIds.delete(databaseId) + } + + addWebSocketConnection(databaseId: string, websocket: WebSocket): boolean { + if (this.websocketConnections.has(databaseId)) { + debug('WebSocket connection already exists for database', databaseId) + return false + } + this.websocketConnections.set(databaseId, websocket) + return true + } + + removeWebSocketConnection(databaseId: string) { + this.websocketConnections.delete(databaseId) + } + + sendMessageToWebSocket( + databaseId: string, + message: ArrayBuffer | Uint8Array, + force: boolean = false + ) { + const websocket = this.websocketConnections.get(databaseId) + + if (!websocket) { + debug('Ignoring message: No websocket connection for database', databaseId) + return + } + + const activeConnectionId = this.activeConnectionIds.get(databaseId) + + if (!activeConnectionId && !force) { + console.log({ message }) + debug('Ignoring message: No active connection for database', databaseId) + return + } + + debug('Sending message to websocket', { databaseId, message }) + websocket?.send(message) + } + + sendMessageToTcp(databaseId: string, connectionId: string, data: Buffer) { + const tcpConnection = this.tcpConnections.get(databaseId) + const activeConnectionId = this.activeConnectionIds.get(databaseId) + + if (!tcpConnection || activeConnectionId !== connectionId) { + debug('Ignoring message: No TCP connection for database or connection ID mismatch', { + databaseId, + connectionId, + }) + return + } + + debug('Sending message to TCP connection', { databaseId, data }) + tcpConnection.streamWriter?.write(data) + } + + isActiveConnection(databaseId: string, connectionId: string) { + return this.activeConnectionIds.get(databaseId) === connectionId + } + + hasWebSocketConnection(databaseId: string): boolean { + return this.websocketConnections.has(databaseId) + } + + hasTcpConnection(databaseId: string): boolean { + return this.tcpConnections.has(databaseId) + } + + getWebSocketConnection(databaseId: string): WebSocket | undefined { + return this.websocketConnections.get(databaseId) + } + + getActiveConnectionId(databaseId: string): string | undefined { + return this.activeConnectionIds.get(databaseId) + } +} diff --git a/apps/browser-proxy/src/index.ts b/apps/browser-proxy/src/index.ts index f365f31e..48678bca 100644 --- a/apps/browser-proxy/src/index.ts +++ b/apps/browser-proxy/src/index.ts @@ -15,11 +15,11 @@ import { UserConnected, UserDisconnected, } from './telemetry.ts' +import { ConnectionManager } from './connection-manager.ts' -const debug = makeDebug('browser-proxy') +const connectionManager = new ConnectionManager() -const tcpConnections = new Map() -const websocketConnections = new Map() +const debug = makeDebug('browser-proxy') const httpsServer = https.createServer({ SNICallback: (servername, callback) => { @@ -45,37 +45,43 @@ websocketServer.on('error', (error) => { debug('websocket server error', error) }) -websocketServer.on('connection', (socket, request) => { +websocketServer.on('connection', (websocket, request) => { debug('websocket connection') const host = request.headers.host if (!host) { debug('No host header present') - socket.close() + websocket.close() return } const databaseId = extractDatabaseId(host) - if (websocketConnections.has(databaseId)) { - socket.send('sorry, too many clients already') - socket.close() + if (!connectionManager.addWebSocketConnection(databaseId, websocket)) { + websocket.send('sorry, too many clients already') + websocket.close() return } - websocketConnections.set(databaseId, socket) - logEvent(new DatabaseShared({ databaseId })) - socket.on('message', (data: Buffer) => { - debug('websocket message', data.toString('hex')) - const tcpConnection = tcpConnections.get(databaseId) - tcpConnection?.streamWriter?.write(data) + websocket.on('message', (data: Buffer) => { + if (data.length === 0) { + return + } + + const activeConnectionId = connectionManager.getActiveConnectionId(databaseId) + if (!activeConnectionId) { + debug('Ignoring message: No active connection for database', databaseId) + return + } + + connectionManager.sendMessageToTcp(databaseId, activeConnectionId, data) }) - socket.on('close', () => { - websocketConnections.delete(databaseId) + websocket.on('close', () => { + connectionManager.removeWebSocketConnection(databaseId) logEvent(new DatabaseUnshared({ databaseId })) }) }) @@ -89,6 +95,7 @@ const tcpServer = net.createServer() tcpServer.on('connection', async (socket) => { let databaseId: string | undefined + let connectionId: string | null = null const connection = await fromNodeSocket(socket, { tls: getTls, @@ -103,7 +110,7 @@ tcpServer.on('connection', async (socket) => { const _databaseId = extractDatabaseId(state.tlsInfo.serverName!) - if (!websocketConnections.has(_databaseId!)) { + if (!connectionManager.hasWebSocketConnection(_databaseId!)) { throw BackendError.create({ code: 'XX000', message: 'the browser is not sharing the database', @@ -111,7 +118,7 @@ tcpServer.on('connection', async (socket) => { }) } - if (tcpConnections.has(_databaseId)) { + if (connectionManager.hasTcpConnection(_databaseId)) { throw BackendError.create({ code: '53300', message: 'sorry, too many clients already', @@ -119,16 +126,20 @@ tcpServer.on('connection', async (socket) => { }) } - // only set the databaseId after we've verified the connection databaseId = _databaseId - tcpConnections.set(databaseId!, connection) + connectionId = connectionManager.addTcpConnection(databaseId, connection) + if (!connectionId) { + debug('Rejecting new TCP connection: already exists for database', databaseId) + socket.destroy() + return + } logEvent(new UserConnected({ databaseId })) }, serverVersion() { return '16.3' }, onAuthenticated() { - const websocket = websocketConnections.get(databaseId!) + const websocket = connectionManager.getWebSocketConnection(databaseId!) if (!websocket) { throw BackendError.create({ @@ -144,34 +155,31 @@ tcpServer.on('connection', async (socket) => { websocket.send(clientIpMessage) }, onMessage(message, state) { - if (!state.isAuthenticated) { + if (!state.isAuthenticated || !databaseId || !connectionId) { return } - const websocket = websocketConnections.get(databaseId!) - - if (!websocket) { - throw BackendError.create({ - code: 'XX000', - message: 'the browser is not sharing the database', - severity: 'FATAL', - }) + if (!connectionManager.isActiveConnection(databaseId, connectionId)) { + debug('Ignoring message for inactive connection', { databaseId, connectionId }) + return new Uint8Array() } debug('tcp message', { message }) - websocket.send(message) + connectionManager.processMessage(databaseId, connectionId, message) - // return an empty buffer to indicate that the message has been handled return new Uint8Array() }, }) socket.on('close', () => { if (databaseId) { - tcpConnections.delete(databaseId) + connectionManager.removeTcpConnection(databaseId) logEvent(new UserDisconnected({ databaseId })) - const websocket = websocketConnections.get(databaseId) - websocket?.send(createStartupMessage('postgres', 'postgres', { client_ip: '' })) + connectionManager.sendMessageToWebSocket( + databaseId, + createStartupMessage('postgres', 'postgres', { client_ip: '' }), + true + ) } }) }) diff --git a/apps/browser-proxy/test/schema.prisma b/apps/browser-proxy/test/schema.prisma new file mode 100644 index 00000000..dffbbcf8 --- /dev/null +++ b/apps/browser-proxy/test/schema.prisma @@ -0,0 +1,59 @@ +generator client { + provider = "prisma-client-js" +} + +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +model comments { + id BigInt @id @default(autoincrement()) + post_id BigInt? + user_id BigInt? + content String + created_at DateTime? @default(now()) @db.Timestamptz(6) + posts posts? @relation(fields: [post_id], references: [id], onDelete: NoAction, onUpdate: NoAction) + users users? @relation(fields: [user_id], references: [id], onDelete: NoAction, onUpdate: NoAction) +} + +model friendships { + id BigInt @id @default(autoincrement()) + user_id BigInt? + friend_id BigInt? + created_at DateTime? @default(now()) @db.Timestamptz(6) + users_friendships_friend_idTousers users? @relation("friendships_friend_idTousers", fields: [friend_id], references: [id], onDelete: NoAction, onUpdate: NoAction) + users_friendships_user_idTousers users? @relation("friendships_user_idTousers", fields: [user_id], references: [id], onDelete: NoAction, onUpdate: NoAction) +} + +model likes { + id BigInt @id @default(autoincrement()) + post_id BigInt? + user_id BigInt? + created_at DateTime? @default(now()) @db.Timestamptz(6) + posts posts? @relation(fields: [post_id], references: [id], onDelete: NoAction, onUpdate: NoAction) + users users? @relation(fields: [user_id], references: [id], onDelete: NoAction, onUpdate: NoAction) +} + +model posts { + id BigInt @id @default(autoincrement()) + user_id BigInt? + content String + created_at DateTime? @default(now()) @db.Timestamptz(6) + comments comments[] + likes likes[] + users users? @relation(fields: [user_id], references: [id], onDelete: NoAction, onUpdate: NoAction) +} + +model users { + id BigInt @id @default(autoincrement()) + name String + email String @unique + password String + created_at DateTime? @default(now()) @db.Timestamptz(6) + comments comments[] + friendships_friendships_friend_idTousers friendships[] @relation("friendships_friend_idTousers") + friendships_friendships_user_idTousers friendships[] @relation("friendships_user_idTousers") + likes likes[] + posts posts[] +} diff --git a/apps/postgres-new/components/app-provider.tsx b/apps/postgres-new/components/app-provider.tsx index 5ffed968..f859cd4a 100644 --- a/apps/postgres-new/components/app-provider.tsx +++ b/apps/postgres-new/components/app-provider.tsx @@ -145,6 +145,7 @@ export default function AppProvider({ children }: AppProps) { // client disconnected if (parameters.client_ip === '') { setConnectedClientIp(null) + await db.query('discard all') await dbManager.closeDbInstance(databaseId) } else { db = await dbManager.getDbInstance(databaseId) diff --git a/package-lock.json b/package-lock.json index 5b98f519..d380332d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -16,6 +16,7 @@ "name": "@database.build/browser-proxy", "dependencies": { "@aws-sdk/client-s3": "^3.645.0", + "async-mutex": "^0.5.0", "debug": "^4.3.7", "expiry-map": "^2.0.0", "findhit-proxywrap": "^0.3.13", @@ -27,6 +28,7 @@ "@total-typescript/tsconfig": "^1.0.4", "@types/debug": "^4.1.12", "@types/node": "^22.5.4", + "prisma": "^5.19.1", "typescript": "^5.5.4" } }, @@ -2499,6 +2501,56 @@ "node": ">=14" } }, + "node_modules/@prisma/debug": { + "version": "5.19.1", + "resolved": "https://registry.npmjs.org/@prisma/debug/-/debug-5.19.1.tgz", + "integrity": "sha512-lAG6A6QnG2AskAukIEucYJZxxcSqKsMK74ZFVfCTOM/7UiyJQi48v6TQ47d6qKG3LbMslqOvnTX25dj/qvclGg==", + "dev": true, + "license": "Apache-2.0" + }, + "node_modules/@prisma/engines": { + "version": "5.19.1", + "resolved": "https://registry.npmjs.org/@prisma/engines/-/engines-5.19.1.tgz", + "integrity": "sha512-kR/PoxZDrfUmbbXqqb8SlBBgCjvGaJYMCOe189PEYzq9rKqitQ2fvT/VJ8PDSe8tTNxhc2KzsCfCAL+Iwm/7Cg==", + "dev": true, + "hasInstallScript": true, + "license": "Apache-2.0", + "dependencies": { + "@prisma/debug": "5.19.1", + "@prisma/engines-version": "5.19.1-2.69d742ee20b815d88e17e54db4a2a7a3b30324e3", + "@prisma/fetch-engine": "5.19.1", + "@prisma/get-platform": "5.19.1" + } + }, + "node_modules/@prisma/engines-version": { + "version": "5.19.1-2.69d742ee20b815d88e17e54db4a2a7a3b30324e3", + "resolved": "https://registry.npmjs.org/@prisma/engines-version/-/engines-version-5.19.1-2.69d742ee20b815d88e17e54db4a2a7a3b30324e3.tgz", + "integrity": "sha512-xR6rt+z5LnNqTP5BBc+8+ySgf4WNMimOKXRn6xfNRDSpHvbOEmd7+qAOmzCrddEc4Cp8nFC0txU14dstjH7FXA==", + "dev": true, + "license": "Apache-2.0" + }, + "node_modules/@prisma/fetch-engine": { + "version": "5.19.1", + "resolved": "https://registry.npmjs.org/@prisma/fetch-engine/-/fetch-engine-5.19.1.tgz", + "integrity": "sha512-pCq74rtlOVJfn4pLmdJj+eI4P7w2dugOnnTXpRilP/6n5b2aZiA4ulJlE0ddCbTPkfHmOL9BfaRgA8o+1rfdHw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@prisma/debug": "5.19.1", + "@prisma/engines-version": "5.19.1-2.69d742ee20b815d88e17e54db4a2a7a3b30324e3", + "@prisma/get-platform": "5.19.1" + } + }, + "node_modules/@prisma/get-platform": { + "version": "5.19.1", + "resolved": "https://registry.npmjs.org/@prisma/get-platform/-/get-platform-5.19.1.tgz", + "integrity": "sha512-sCeoJ+7yt0UjnR+AXZL7vXlg5eNxaFOwC23h0KvW1YIXUoa7+W2ZcAUhoEQBmJTW4GrFqCuZ8YSP0mkDa4k3Zg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@prisma/debug": "5.19.1" + } + }, "node_modules/@protobufjs/aspromise": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/@protobufjs/aspromise/-/aspromise-1.1.2.tgz", @@ -12382,6 +12434,26 @@ "sql-formatter": "bin/sql-formatter-cli.cjs" } }, + "node_modules/prisma": { + "version": "5.19.1", + "resolved": "https://registry.npmjs.org/prisma/-/prisma-5.19.1.tgz", + "integrity": "sha512-c5K9MiDaa+VAAyh1OiYk76PXOme9s3E992D7kvvIOhCrNsBQfy2mP2QAQtX0WNj140IgG++12kwZpYB9iIydNQ==", + "dev": true, + "hasInstallScript": true, + "license": "Apache-2.0", + "dependencies": { + "@prisma/engines": "5.19.1" + }, + "bin": { + "prisma": "build/index.js" + }, + "engines": { + "node": ">=16.13" + }, + "optionalDependencies": { + "fsevents": "2.3.3" + } + }, "node_modules/prismjs": { "version": "1.29.0", "resolved": "https://registry.npmjs.org/prismjs/-/prismjs-1.29.0.tgz",