From d79110dcc172fb3d97956c48bfd8415145d546cf Mon Sep 17 00:00:00 2001 From: Max Schmitt Date: Tue, 1 Jun 2021 14:13:23 -0700 Subject: [PATCH] fix(port-forwarding): close socket on unexpected payloads (#6753) --- src/server/socksServer.ts | 59 +++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 18 deletions(-) diff --git a/src/server/socksServer.ts b/src/server/socksServer.ts index 8ef128b25c..2de08a8d0c 100644 --- a/src/server/socksServer.ts +++ b/src/server/socksServer.ts @@ -15,7 +15,7 @@ */ import net from 'net'; -import { assert } from '../utils/utils'; +import { debugLogger } from '../utils/debugLogger'; import { SdkObject } from './instrumentation'; export type SocksConnectionInfo = { @@ -77,8 +77,9 @@ class SocksV5ServerParser { private _dstAddrp: number = 0; private _dstPort?: number; private _socket: net.Socket; - private _readyResolve!: (value?: unknown) => void; - private _ready: Promise; + private _parsingFinishedResolve!: (value?: unknown) => void; + private _parsingFinishedReject!: (value: Error) => void; + private _parsingFinished: Promise; private _info: SocksConnectionInfo; private _phase: ConnectionPhases = ConnectionPhases.VERSION; private _authMethods?: Buffer; @@ -88,7 +89,10 @@ class SocksV5ServerParser { constructor(socket: net.Socket) { this._socket = socket; this._info = { srcAddr: socket.remoteAddress!, srcPort: socket.remotePort!, dstAddr: '', dstPort: 0 }; - this._ready = new Promise(resolve => this._readyResolve = resolve); + this._parsingFinished = new Promise((resolve, reject) => { + this._parsingFinishedResolve = resolve; + this._parsingFinishedReject = reject; + }); socket.on('data', this._onData.bind(this)); socket.on('error', () => {}); } @@ -96,10 +100,15 @@ class SocksV5ServerParser { const socket = this._socket; let i = 0; const readByte = () => chunk[i++]; + const closeSocketOnError = () => { + socket.end(); + this._parsingFinishedReject(new Error('Parsing aborted')); + }; while (i < chunk.length && this._phase !== ConnectionPhases.DONE) { switch (this._phase) { case ConnectionPhases.VERSION: - assert(readByte() === SOCKS_VERSION); + if (readByte() !== SOCKS_VERSION) + return closeSocketOnError(); this._phase = ConnectionPhases.NMETHODS; break; @@ -109,16 +118,19 @@ class SocksV5ServerParser { break; case ConnectionPhases.METHODS: { - assert(this._authMethods); + if (!this._authMethods) + return closeSocketOnError(); chunk.copy(this._authMethods, 0, i, i + chunk.length); - assert(this._authMethods.includes(SOCKS_AUTH_METHOD.NO_AUTH)); + if (!this._authMethods.includes(SOCKS_AUTH_METHOD.NO_AUTH)) + return closeSocketOnError(); const left = this._authMethods.length - this._methodsp; const chunkLeft = chunk.length - i; const minLen = (left < chunkLeft ? left : chunkLeft); chunk.copy(this._authMethods, this._methodsp, i, i + minLen); this._methodsp += minLen; i += minLen; - assert(this._methodsp === this._authMethods.length); + if (this._methodsp !== this._authMethods.length) + return closeSocketOnError(); if (i < chunk.length) this._socket.unshift(chunk.slice(i)); this._authWithoutPassword(socket); @@ -127,9 +139,11 @@ class SocksV5ServerParser { } case ConnectionPhases.REQ_CMD: - assert(readByte() === SOCKS_VERSION); + if (readByte() !== SOCKS_VERSION) + return closeSocketOnError(); const cmd: SOCKS_CMD = readByte(); - assert(cmd === SOCKS_CMD.CONNECT); + if (cmd !== SOCKS_CMD.CONNECT) + return closeSocketOnError(); this._phase = ConnectionPhases.REQ_RSV; break; @@ -141,7 +155,8 @@ class SocksV5ServerParser { case ConnectionPhases.REQ_ATYP: this._phase = ConnectionPhases.REQ_DSTADDR; this._addressType = readByte(); - assert(this._addressType in SOCKS_ATYP); + if (!(this._addressType in SOCKS_ATYP)) + return closeSocketOnError(); if (this._addressType === SOCKS_ATYP.IPv4) this._dstAddr = Buffer.alloc(4); else if (this._addressType === SOCKS_ATYP.IPv6) @@ -151,7 +166,8 @@ class SocksV5ServerParser { break; case ConnectionPhases.REQ_DSTADDR: { - assert(this._dstAddr); + if (!this._dstAddr) + return closeSocketOnError(); const left = this._dstAddr.length - this._dstAddrp; const chunkLeft = chunk.length - i; const minLen = (left < chunkLeft ? left : chunkLeft); @@ -169,7 +185,8 @@ class SocksV5ServerParser { break; case ConnectionPhases.REQ_DSTPORT: - assert(this._dstAddr); + if (!this._dstAddr) + return closeSocketOnError(); if (this._dstPort === undefined) { this._dstPort = readByte(); break; @@ -197,10 +214,10 @@ class SocksV5ServerParser { } this._info.dstPort = this._dstPort; this._phase = ConnectionPhases.DONE; - this._readyResolve(); + this._parsingFinishedResolve(); return; default: - assert(false); + return closeSocketOnError(); } } } @@ -210,7 +227,7 @@ class SocksV5ServerParser { } async ready(): Promise<{ info: SocksConnectionInfo, forward: () => void, intercept: (parent: SdkObject) => SocksInterceptedSocketHandler }> { - await this._ready; + await this._parsingFinished; return { info: this._info, forward: () => { @@ -290,8 +307,14 @@ export class SocksProxyServer { async _handleConnection(incomingMessageHandler: IncomingProxyRequestHandler, socket: net.Socket) { const parser = new SocksV5ServerParser(socket); - const { info, forward, intercept } = await parser.ready(); - incomingMessageHandler(info, forward, intercept); + let parsedSocket; + try { + parsedSocket = await parser.ready(); + } catch (error) { + debugLogger.log('proxy', `Could not parse: ${error} ${error?.stack}`); + return; + } + incomingMessageHandler(parsedSocket.info, parsedSocket.forward, parsedSocket.intercept); } public close() {