fix(port-forwarding): close socket on unexpected payloads (#6753)

This commit is contained in:
Max Schmitt 2021-06-01 14:13:23 -07:00 committed by GitHub
parent 531d35f945
commit d79110dcc1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -15,7 +15,7 @@
*/ */
import net from 'net'; import net from 'net';
import { assert } from '../utils/utils'; import { debugLogger } from '../utils/debugLogger';
import { SdkObject } from './instrumentation'; import { SdkObject } from './instrumentation';
export type SocksConnectionInfo = { export type SocksConnectionInfo = {
@ -77,8 +77,9 @@ class SocksV5ServerParser {
private _dstAddrp: number = 0; private _dstAddrp: number = 0;
private _dstPort?: number; private _dstPort?: number;
private _socket: net.Socket; private _socket: net.Socket;
private _readyResolve!: (value?: unknown) => void; private _parsingFinishedResolve!: (value?: unknown) => void;
private _ready: Promise<unknown>; private _parsingFinishedReject!: (value: Error) => void;
private _parsingFinished: Promise<unknown>;
private _info: SocksConnectionInfo; private _info: SocksConnectionInfo;
private _phase: ConnectionPhases = ConnectionPhases.VERSION; private _phase: ConnectionPhases = ConnectionPhases.VERSION;
private _authMethods?: Buffer; private _authMethods?: Buffer;
@ -88,7 +89,10 @@ class SocksV5ServerParser {
constructor(socket: net.Socket) { constructor(socket: net.Socket) {
this._socket = socket; this._socket = socket;
this._info = { srcAddr: socket.remoteAddress!, srcPort: socket.remotePort!, dstAddr: '', dstPort: 0 }; this._info = { srcAddr: socket.remoteAddress!, srcPort: socket.remotePort!, dstAddr: '', dstPort: 0 };
this._ready = new Promise(resolve => this._readyResolve = resolve); this._parsingFinished = new Promise((resolve, reject) => {
this._parsingFinishedResolve = resolve;
this._parsingFinishedReject = reject;
});
socket.on('data', this._onData.bind(this)); socket.on('data', this._onData.bind(this));
socket.on('error', () => {}); socket.on('error', () => {});
} }
@ -96,10 +100,15 @@ class SocksV5ServerParser {
const socket = this._socket; const socket = this._socket;
let i = 0; let i = 0;
const readByte = () => chunk[i++]; const readByte = () => chunk[i++];
const closeSocketOnError = () => {
socket.end();
this._parsingFinishedReject(new Error('Parsing aborted'));
};
while (i < chunk.length && this._phase !== ConnectionPhases.DONE) { while (i < chunk.length && this._phase !== ConnectionPhases.DONE) {
switch (this._phase) { switch (this._phase) {
case ConnectionPhases.VERSION: case ConnectionPhases.VERSION:
assert(readByte() === SOCKS_VERSION); if (readByte() !== SOCKS_VERSION)
return closeSocketOnError();
this._phase = ConnectionPhases.NMETHODS; this._phase = ConnectionPhases.NMETHODS;
break; break;
@ -109,16 +118,19 @@ class SocksV5ServerParser {
break; break;
case ConnectionPhases.METHODS: { case ConnectionPhases.METHODS: {
assert(this._authMethods); if (!this._authMethods)
return closeSocketOnError();
chunk.copy(this._authMethods, 0, i, i + chunk.length); chunk.copy(this._authMethods, 0, i, i + chunk.length);
assert(this._authMethods.includes(SOCKS_AUTH_METHOD.NO_AUTH)); if (!this._authMethods.includes(SOCKS_AUTH_METHOD.NO_AUTH))
return closeSocketOnError();
const left = this._authMethods.length - this._methodsp; const left = this._authMethods.length - this._methodsp;
const chunkLeft = chunk.length - i; const chunkLeft = chunk.length - i;
const minLen = (left < chunkLeft ? left : chunkLeft); const minLen = (left < chunkLeft ? left : chunkLeft);
chunk.copy(this._authMethods, this._methodsp, i, i + minLen); chunk.copy(this._authMethods, this._methodsp, i, i + minLen);
this._methodsp += minLen; this._methodsp += minLen;
i += minLen; i += minLen;
assert(this._methodsp === this._authMethods.length); if (this._methodsp !== this._authMethods.length)
return closeSocketOnError();
if (i < chunk.length) if (i < chunk.length)
this._socket.unshift(chunk.slice(i)); this._socket.unshift(chunk.slice(i));
this._authWithoutPassword(socket); this._authWithoutPassword(socket);
@ -127,9 +139,11 @@ class SocksV5ServerParser {
} }
case ConnectionPhases.REQ_CMD: case ConnectionPhases.REQ_CMD:
assert(readByte() === SOCKS_VERSION); if (readByte() !== SOCKS_VERSION)
return closeSocketOnError();
const cmd: SOCKS_CMD = readByte(); const cmd: SOCKS_CMD = readByte();
assert(cmd === SOCKS_CMD.CONNECT); if (cmd !== SOCKS_CMD.CONNECT)
return closeSocketOnError();
this._phase = ConnectionPhases.REQ_RSV; this._phase = ConnectionPhases.REQ_RSV;
break; break;
@ -141,7 +155,8 @@ class SocksV5ServerParser {
case ConnectionPhases.REQ_ATYP: case ConnectionPhases.REQ_ATYP:
this._phase = ConnectionPhases.REQ_DSTADDR; this._phase = ConnectionPhases.REQ_DSTADDR;
this._addressType = readByte(); this._addressType = readByte();
assert(this._addressType in SOCKS_ATYP); if (!(this._addressType in SOCKS_ATYP))
return closeSocketOnError();
if (this._addressType === SOCKS_ATYP.IPv4) if (this._addressType === SOCKS_ATYP.IPv4)
this._dstAddr = Buffer.alloc(4); this._dstAddr = Buffer.alloc(4);
else if (this._addressType === SOCKS_ATYP.IPv6) else if (this._addressType === SOCKS_ATYP.IPv6)
@ -151,7 +166,8 @@ class SocksV5ServerParser {
break; break;
case ConnectionPhases.REQ_DSTADDR: { case ConnectionPhases.REQ_DSTADDR: {
assert(this._dstAddr); if (!this._dstAddr)
return closeSocketOnError();
const left = this._dstAddr.length - this._dstAddrp; const left = this._dstAddr.length - this._dstAddrp;
const chunkLeft = chunk.length - i; const chunkLeft = chunk.length - i;
const minLen = (left < chunkLeft ? left : chunkLeft); const minLen = (left < chunkLeft ? left : chunkLeft);
@ -169,7 +185,8 @@ class SocksV5ServerParser {
break; break;
case ConnectionPhases.REQ_DSTPORT: case ConnectionPhases.REQ_DSTPORT:
assert(this._dstAddr); if (!this._dstAddr)
return closeSocketOnError();
if (this._dstPort === undefined) { if (this._dstPort === undefined) {
this._dstPort = readByte(); this._dstPort = readByte();
break; break;
@ -197,10 +214,10 @@ class SocksV5ServerParser {
} }
this._info.dstPort = this._dstPort; this._info.dstPort = this._dstPort;
this._phase = ConnectionPhases.DONE; this._phase = ConnectionPhases.DONE;
this._readyResolve(); this._parsingFinishedResolve();
return; return;
default: default:
assert(false); return closeSocketOnError();
} }
} }
} }
@ -210,7 +227,7 @@ class SocksV5ServerParser {
} }
async ready(): Promise<{ info: SocksConnectionInfo, forward: () => void, intercept: (parent: SdkObject) => SocksInterceptedSocketHandler }> { async ready(): Promise<{ info: SocksConnectionInfo, forward: () => void, intercept: (parent: SdkObject) => SocksInterceptedSocketHandler }> {
await this._ready; await this._parsingFinished;
return { return {
info: this._info, info: this._info,
forward: () => { forward: () => {
@ -290,8 +307,14 @@ export class SocksProxyServer {
async _handleConnection(incomingMessageHandler: IncomingProxyRequestHandler, socket: net.Socket) { async _handleConnection(incomingMessageHandler: IncomingProxyRequestHandler, socket: net.Socket) {
const parser = new SocksV5ServerParser(socket); const parser = new SocksV5ServerParser(socket);
const { info, forward, intercept } = await parser.ready(); let parsedSocket;
incomingMessageHandler(info, forward, intercept); try {
parsedSocket = await parser.ready();
} catch (error) {
debugLogger.log('proxy', `Could not parse: ${error} ${error?.stack}`);
return;
}
incomingMessageHandler(parsedSocket.info, parsedSocket.forward, parsedSocket.intercept);
} }
public close() { public close() {