diff --git a/src/browserServerImpl.ts b/src/browserServerImpl.ts index 979b248473..53bce3cfde 100644 --- a/src/browserServerImpl.ts +++ b/src/browserServerImpl.ts @@ -34,9 +34,6 @@ import { BrowserContext } from './server/browserContext'; import { CRBrowser } from './server/chromium/crBrowser'; import { CDPSessionDispatcher } from './dispatchers/cdpSessionDispatcher'; import { PageDispatcher } from './dispatchers/pageDispatcher'; -import { BrowserServerPortForwardingServer } from './server/socksSocket'; -import { SocksSocketDispatcher } from './dispatchers/socksSocketDispatcher'; -import { SocksInterceptedSocketHandler } from './server/socksServer'; export class BrowserServerLauncherImpl implements BrowserServerLauncher { private _browserName: 'chromium' | 'firefox' | 'webkit'; @@ -47,14 +44,14 @@ export class BrowserServerLauncherImpl implements BrowserServerLauncher { async launchServer(options: LaunchServerOptions = {}): Promise { const playwright = createPlaywright(); - const portForwardingServer = new BrowserServerPortForwardingServer(playwright, !!options._acceptForwardedPorts); + if (options._acceptForwardedPorts) + await playwright._enablePortForwarding(); // 1. Pre-launch the browser const browser = await playwright[this._browserName].launch(internalCallMetadata(), { ...options, ignoreDefaultArgs: Array.isArray(options.ignoreDefaultArgs) ? options.ignoreDefaultArgs : undefined, ignoreAllDefaultArgs: !!options.ignoreDefaultArgs && !Array.isArray(options.ignoreDefaultArgs), env: options.env ? envObjectToArray(options.env) : undefined, - ...portForwardingServer.browserLaunchOptions(), }, toProtocolLogger(options.logger)); // 2. Start the server @@ -62,9 +59,9 @@ export class BrowserServerLauncherImpl implements BrowserServerLauncher { path: '/' + createGuid(), allowMultipleClients: options._acceptForwardedPorts ? false : true, onClose: () => { - portForwardingServer.stop(); + playwright._disablePortForwarding(); }, - onConnect: this._onConnect.bind(this, playwright, browser, portForwardingServer), + onConnect: this._onConnect.bind(this, playwright, browser), }; const server = new PlaywrightServer(delegate); const wsEndpoint = await server.listen(options.port); @@ -83,7 +80,7 @@ export class BrowserServerLauncherImpl implements BrowserServerLauncher { return browserServer; } - private _onConnect(playwright: Playwright, browser: Browser, portForwardingServer: BrowserServerPortForwardingServer, scope: DispatcherScope, forceDisconnect: () => void) { + private async _onConnect(playwright: Playwright, browser: Browser, scope: DispatcherScope, forceDisconnect: () => void) { const selectors = new Selectors(); const selectorsDispatcher = new SelectorsDispatcher(scope, selectors); const browserDispatcher = new ConnectedBrowserDispatcher(scope, browser, selectors); @@ -91,16 +88,8 @@ export class BrowserServerLauncherImpl implements BrowserServerLauncher { // Underlying browser did close for some reason - force disconnect the client. forceDisconnect(); }); - const playwrightDispatcher = new PlaywrightDispatcher(scope, playwright, selectorsDispatcher, browserDispatcher, (ports: number[]) => { - portForwardingServer.enablePortForwarding(ports); - }); - const incomingSocksSocketHandler = (socket: SocksInterceptedSocketHandler) => { - playwrightDispatcher._dispatchEvent('incomingSocksSocket', { socket: new SocksSocketDispatcher(playwrightDispatcher, socket) }); - }; - portForwardingServer.on('incomingSocksSocket', incomingSocksSocketHandler); - + new PlaywrightDispatcher(scope, playwright, selectorsDispatcher, browserDispatcher); return () => { - portForwardingServer.off('incomingSocksSocket', incomingSocksSocketHandler); // Cleanup contexts upon disconnect. browserDispatcher.cleanupContexts().catch(e => {}); }; diff --git a/src/cli/driver.ts b/src/cli/driver.ts index 306d6b595f..351531cf7b 100644 --- a/src/cli/driver.ts +++ b/src/cli/driver.ts @@ -54,7 +54,7 @@ export function runDriver() { } export async function runServer(port: number | undefined) { - const wsEndpoint = await PlaywrightServer.startDefault(port); + const wsEndpoint = await PlaywrightServer.startDefault({port}); console.log('Listening on ' + wsEndpoint); // eslint-disable-line no-console } diff --git a/src/client/browserType.ts b/src/client/browserType.ts index 68d13d4a8a..781be4076e 100644 --- a/src/client/browserType.ts +++ b/src/client/browserType.ts @@ -192,10 +192,7 @@ export class BrowserType extends ChannelOwner SocksSocket.from(socket)); } + async _enablePortForwarding(ports: number[]) { + this._forwardPorts = ports; + await this._channel.setForwardedPorts({ports}); + } + _cleanup() { this.selectors._removeChannel(this._selectorsOwner); } diff --git a/src/dispatchers/playwrightDispatcher.ts b/src/dispatchers/playwrightDispatcher.ts index 7e7a59a502..c3285603bf 100644 --- a/src/dispatchers/playwrightDispatcher.ts +++ b/src/dispatchers/playwrightDispatcher.ts @@ -22,11 +22,11 @@ import { Dispatcher, DispatcherScope } from './dispatcher'; import { ElectronDispatcher } from './electronDispatcher'; import { SelectorsDispatcher } from './selectorsDispatcher'; import * as types from '../server/types'; -import { assert } from '../utils/utils'; +import { SocksSocketDispatcher } from './socksSocketDispatcher'; +import { SocksInterceptedSocketHandler } from '../server/socksServer'; export class PlaywrightDispatcher extends Dispatcher implements channels.PlaywrightChannel { - private _portForwardingCallback: ((ports: number[]) => void) | undefined; - constructor(scope: DispatcherScope, playwright: Playwright, customSelectors?: channels.SelectorsChannel, preLaunchedBrowser?: channels.BrowserChannel, portForwardingCallback?: (ports: number[]) => void) { + constructor(scope: DispatcherScope, playwright: Playwright, customSelectors?: channels.SelectorsChannel, preLaunchedBrowser?: channels.BrowserChannel) { const descriptors = require('../server/deviceDescriptors') as types.Devices; const deviceDescriptors = Object.entries(descriptors) .map(([name, descriptor]) => ({ name, descriptor })); @@ -40,11 +40,12 @@ export class PlaywrightDispatcher extends Dispatcher { + this._dispatchEvent('incomingSocksSocket', { socket: new SocksSocketDispatcher(this, socket) }); + }); } - async enablePortForwarding(params: channels.PlaywrightEnablePortForwardingParams): Promise { - assert(this._portForwardingCallback, 'Port forwarding is only supported when using connect()'); - this._portForwardingCallback(params.ports); + async setForwardedPorts(params: channels.PlaywrightSetForwardedPortsParams): Promise { + this._object._setForwardedPorts(params.ports); } } diff --git a/src/protocol/channels.ts b/src/protocol/channels.ts index c986b2659a..b521463530 100644 --- a/src/protocol/channels.ts +++ b/src/protocol/channels.ts @@ -180,18 +180,18 @@ export type PlaywrightInitializer = { }; export interface PlaywrightChannel extends Channel { on(event: 'incomingSocksSocket', callback: (params: PlaywrightIncomingSocksSocketEvent) => void): this; - enablePortForwarding(params: PlaywrightEnablePortForwardingParams, metadata?: Metadata): Promise; + setForwardedPorts(params: PlaywrightSetForwardedPortsParams, metadata?: Metadata): Promise; } export type PlaywrightIncomingSocksSocketEvent = { socket: SocksSocketChannel, }; -export type PlaywrightEnablePortForwardingParams = { +export type PlaywrightSetForwardedPortsParams = { ports: number[], }; -export type PlaywrightEnablePortForwardingOptions = { +export type PlaywrightSetForwardedPortsOptions = { }; -export type PlaywrightEnablePortForwardingResult = void; +export type PlaywrightSetForwardedPortsResult = void; // ----------- Selectors ----------- export type SelectorsInitializer = {}; diff --git a/src/protocol/protocol.yml b/src/protocol/protocol.yml index 55d6831794..2e2b88dab0 100644 --- a/src/protocol/protocol.yml +++ b/src/protocol/protocol.yml @@ -368,7 +368,7 @@ Playwright: commands: - enablePortForwarding: + setForwardedPorts: parameters: ports: type: array diff --git a/src/protocol/validator.ts b/src/protocol/validator.ts index 241fdeb503..20f59c0535 100644 --- a/src/protocol/validator.ts +++ b/src/protocol/validator.ts @@ -147,7 +147,7 @@ export function createScheme(tChannel: (name: string) => Validator): Scheme { })), value: tOptional(tType('SerializedValue')), }); - scheme.PlaywrightEnablePortForwardingParams = tObject({ + scheme.PlaywrightSetForwardedPortsParams = tObject({ ports: tArray(tNumber), }); scheme.SelectorsRegisterParams = tObject({ diff --git a/src/remote/playwrightClient.ts b/src/remote/playwrightClient.ts index e56d7a6227..dbca997b41 100644 --- a/src/remote/playwrightClient.ts +++ b/src/remote/playwrightClient.ts @@ -18,24 +18,43 @@ import WebSocket from 'ws'; import { Connection } from '../client/connection'; import { Playwright } from '../client/playwright'; +export type PlaywrightClientConnectOptions = { + wsEndpoint: string; + forwardPorts?: number[]; + timeout?: number +}; + export class PlaywrightClient { private _playwright: Playwright; private _ws: WebSocket; private _closePromise: Promise; - static async connect(wsEndpoint: string): Promise { + static async connect(options: PlaywrightClientConnectOptions): Promise { + const {wsEndpoint, forwardPorts, timeout = 30000} = options; const connection = new Connection(); const ws = new WebSocket(wsEndpoint); connection.onmessage = message => ws.send(JSON.stringify(message)); ws.on('message', message => connection.dispatch(JSON.parse(message.toString()))); const errorPromise = new Promise((_, reject) => ws.on('error', error => reject(error))); const closePromise = new Promise((_, reject) => ws.on('close', () => reject(new Error('Connection closed')))); - const playwright = await Promise.race([ - connection.waitForObjectWithKnownName('Playwright'), - errorPromise, - closePromise - ]); - return new PlaywrightClient(playwright as Playwright, ws); + const playwrightClientPromise = new Promise(async (resolve, reject) => { + const playwright = await connection.waitForObjectWithKnownName('Playwright') as Playwright; + if (forwardPorts) + await playwright._enablePortForwarding(forwardPorts).catch(reject); + resolve(new PlaywrightClient(playwright, ws)); + }); + let timer: NodeJS.Timeout; + try { + await Promise.race([ + playwrightClientPromise, + errorPromise, + closePromise, + new Promise((_, reject) => timer = setTimeout(reject, timeout)) + ]); + return await playwrightClientPromise; + } finally { + clearTimeout(timer!); + } } constructor(playwright: Playwright, ws: WebSocket) { diff --git a/src/remote/playwrightServer.ts b/src/remote/playwrightServer.ts index 669bbbd474..c557e1df1b 100644 --- a/src/remote/playwrightServer.ts +++ b/src/remote/playwrightServer.ts @@ -28,16 +28,21 @@ const debugLog = debug('pw:server'); export interface PlaywrightServerDelegate { path: string; allowMultipleClients: boolean; - onConnect(rootScope: DispatcherScope, forceDisconnect: () => void): () => any; + onConnect(rootScope: DispatcherScope, forceDisconnect: () => void): Promise<() => any>; onClose: () => any; } +export type PlaywrightServerOptions = { + port?: number; + acceptForwardedPorts?: boolean +}; + export class PlaywrightServer { private _wsServer: ws.Server | undefined; private _clientsCount = 0; private _delegate: PlaywrightServerDelegate; - static async startDefault(port: number = 0): Promise { + static async startDefault({port = 0, acceptForwardedPorts }: PlaywrightServerOptions): Promise { const cleanup = async () => { await gracefullyCloseAll().catch(e => {}); serverSelectors.unregisterAll(); @@ -46,9 +51,15 @@ export class PlaywrightServer { path: '/ws', allowMultipleClients: false, onClose: cleanup, - onConnect: (rootScope: DispatcherScope) => { - new PlaywrightDispatcher(rootScope, createPlaywright()); - return cleanup; + onConnect: async (rootScope: DispatcherScope) => { + const playwright = createPlaywright(); + if (acceptForwardedPorts) + await playwright._enablePortForwarding(); + new PlaywrightDispatcher(rootScope, playwright); + return () => { + cleanup(); + playwright._disablePortForwarding(); + }; }, }; const server = new PlaywrightServer(delegate); @@ -66,12 +77,12 @@ export class PlaywrightServer { server.on('error', error => debugLog(error)); const path = this._delegate.path; - const wsEndpoint = await new Promise(resolve => { + const wsEndpoint = await new Promise((resolve, reject) => { server.listen(port, () => { const address = server.address(); const wsEndpoint = typeof address === 'string' ? `${address}${path}` : `ws://127.0.0.1:${address.port}${path}`; resolve(wsEndpoint); - }); + }).on('error', reject); }); debugLog('Listening at ' + wsEndpoint); @@ -96,7 +107,7 @@ export class PlaywrightServer { const forceDisconnect = () => socket.close(); const scope = connection.rootDispatcher(); - const onDisconnect = this._delegate.onConnect(scope, forceDisconnect); + let onDisconnect = () => {}; const disconnected = () => { this._clientsCount--; // Avoid sending any more messages over closed socket. @@ -111,6 +122,7 @@ export class PlaywrightServer { debugLog('Client error ' + error); disconnected(); }); + onDisconnect = await this._delegate.onConnect(scope, forceDisconnect); }); return wsEndpoint; diff --git a/src/server/browser.ts b/src/server/browser.ts index 3ea9a42913..298de6a0a4 100644 --- a/src/server/browser.ts +++ b/src/server/browser.ts @@ -35,6 +35,7 @@ export interface BrowserProcess { export type PlaywrightOptions = { registry: registry.Registry, rootSdkObject: SdkObject, + loopbackProxyOverride?: () => string, }; export type BrowserOptions = PlaywrightOptions & { diff --git a/src/server/browserType.ts b/src/server/browserType.ts index b9f8b97429..f44f95b971 100644 --- a/src/server/browserType.ts +++ b/src/server/browserType.ts @@ -60,7 +60,7 @@ export abstract class BrowserType extends SdkObject { } async launch(metadata: CallMetadata, options: types.LaunchOptions, protocolLogger?: types.ProtocolLogger): Promise { - options = validateLaunchOptions(options); + options = validateLaunchOptions(options, this._playwrightOptions.loopbackProxyOverride?.()); const controller = new ProgressController(metadata, this); controller.setLogName('browser'); const browser = await controller.run(progress => { @@ -70,7 +70,7 @@ export abstract class BrowserType extends SdkObject { } async launchPersistentContext(metadata: CallMetadata, userDataDir: string, options: types.LaunchPersistentOptions): Promise { - options = validateLaunchOptions(options); + options = validateLaunchOptions(options, this._playwrightOptions.loopbackProxyOverride?.()); const controller = new ProgressController(metadata, this); const persistent: types.BrowserContextOptions = options; controller.setLogName('browser'); @@ -273,12 +273,14 @@ function copyTestHooks(from: object, to: object) { } } -function validateLaunchOptions(options: Options): Options { +function validateLaunchOptions(options: Options, proxyOverride?: string): Options { const { devtools = false } = options; - let { headless = !devtools, downloadsPath } = options; + let { headless = !devtools, downloadsPath, proxy } = options; if (debugMode()) headless = false; if (downloadsPath && !path.isAbsolute(downloadsPath)) downloadsPath = path.join(process.cwd(), downloadsPath); - return { ...options, devtools, headless, downloadsPath }; + if (proxyOverride) + proxy = { server: proxyOverride }; + return { ...options, devtools, headless, downloadsPath, proxy }; } diff --git a/src/server/chromium/chromium.ts b/src/server/chromium/chromium.ts index 9b6fec560f..929b0b422f 100644 --- a/src/server/chromium/chromium.ts +++ b/src/server/chromium/chromium.ts @@ -34,10 +34,6 @@ import { CallMetadata } from '../instrumentation'; import { findChromiumChannel } from './findChromiumChannel'; import http from 'http'; -type LaunchServerOptions = { - _acceptForwardedPorts?: boolean, -}; - export class Chromium extends BrowserType { private _devtools: CRDevTools | undefined; @@ -123,7 +119,7 @@ export class Chromium extends BrowserType { transport.send(message); } - _defaultArgs(options: types.LaunchOptions & LaunchServerOptions, isPersistent: boolean, userDataDir: string): string[] { + _defaultArgs(options: types.LaunchOptions, isPersistent: boolean, userDataDir: string): string[] { const { args = [], proxy } = options; const userDataDirArg = args.find(arg => arg.startsWith('--user-data-dir')); if (userDataDirArg) @@ -161,7 +157,7 @@ export class Chromium extends BrowserType { chromeArguments.push(`--proxy-server=${proxy.server}`); const proxyBypassRules = []; // https://source.chromium.org/chromium/chromium/src/+/master:net/docs/proxy.md;l=548;drc=71698e610121078e0d1a811054dcf9fd89b49578 - if (options._acceptForwardedPorts) + if (this._playwrightOptions.loopbackProxyOverride) proxyBypassRules.push('<-loopback>'); if (proxy.bypass) proxyBypassRules.push(...proxy.bypass.split(',').map(t => t.trim()).map(t => t.startsWith('.') ? '*' + t : t)); diff --git a/src/server/playwright.ts b/src/server/playwright.ts index 3b3c1929e8..0e449afabd 100644 --- a/src/server/playwright.ts +++ b/src/server/playwright.ts @@ -26,6 +26,9 @@ import { WebKit } from './webkit/webkit'; import { Registry } from '../utils/registry'; import { CallMetadata, createInstrumentation, SdkObject } from './instrumentation'; import { debugLogger } from '../utils/debugLogger'; +import { PortForwardingServer } from './socksSocket'; +import { SocksInterceptedSocketHandler } from './socksServer'; +import { assert } from '../utils/utils'; export class Playwright extends SdkObject { readonly selectors: Selectors; @@ -35,6 +38,7 @@ export class Playwright extends SdkObject { readonly firefox: Firefox; readonly webkit: WebKit; readonly options: PlaywrightOptions; + private _portForwardingServer: PortForwardingServer | undefined; constructor(isInternal: boolean) { super({ attribution: { isInternal }, instrumentation: createInstrumentation() } as any, undefined, 'Playwright'); @@ -54,6 +58,27 @@ export class Playwright extends SdkObject { this.android = new Android(new AdbBackend(), this.options); this.selectors = serverSelectors; } + + async _enablePortForwarding() { + assert(!this._portForwardingServer); + this._portForwardingServer = await PortForwardingServer.create(this); + this.options.loopbackProxyOverride = () => this._portForwardingServer!.proxyServer(); + this._portForwardingServer.on('incomingSocksSocket', (socket: SocksInterceptedSocketHandler) => { + this.emit('incomingSocksSocket', socket); + }); + } + + _disablePortForwarding() { + if (!this._portForwardingServer) + return; + this._portForwardingServer.stop(); + } + + _setForwardedPorts(ports: number[]) { + if (!this._portForwardingServer) + throw new Error(`Port forwarding needs to be enabled when launching the server via BrowserType.launchServer.`); + this._portForwardingServer.setForwardedPorts(ports); + } } export function createPlaywright(isInternal = false) { diff --git a/src/server/socksServer.ts b/src/server/socksServer.ts index 2de08a8d0c..d4f424b727 100644 --- a/src/server/socksServer.ts +++ b/src/server/socksServer.ts @@ -301,8 +301,8 @@ export class SocksProxyServer { this.server = net.createServer(this._handleConnection.bind(this, incomingMessageHandler)); } - public listen(port: number, host?: string) { - this.server.listen(port, host); + public async listen(port: number, host?: string) { + await new Promise(resolve => this.server.listen(port, host, resolve)); } async _handleConnection(incomingMessageHandler: IncomingProxyRequestHandler, socket: net.Socket) { diff --git a/src/server/socksSocket.ts b/src/server/socksSocket.ts index 8f3b09e57a..15d0a5da13 100644 --- a/src/server/socksSocket.ts +++ b/src/server/socksSocket.ts @@ -21,39 +21,31 @@ import { SdkObject } from './instrumentation'; import { debugLogger } from '../utils/debugLogger'; import { isLocalIpAddress } from '../utils/utils'; import { SocksProxyServer, SocksConnectionInfo, SocksInterceptedSocketHandler } from './socksServer'; -import { LaunchOptions } from './types'; -export class BrowserServerPortForwardingServer extends EventEmitter { - enabled: boolean; +export class PortForwardingServer extends EventEmitter { private _forwardPorts: number[] = []; private _parent: SdkObject; private _server: SocksProxyServer; - constructor(parent: SdkObject, enabled: boolean) { + constructor(parent: SdkObject) { super(); this.setMaxListeners(0); - this.enabled = enabled; this._parent = parent; this._server = new SocksProxyServer(this._handler.bind(this)); - if (enabled) { - this._server.listen(0); - debugLogger.log('proxy', `initialized server on port ${this._port()})`); - } + } + + static async create(parent: SdkObject) { + const server = new PortForwardingServer(parent); + await server._server.listen(0); + debugLogger.log('proxy', `starting server on port ${server._port()})`); + return server; } private _port(): number { - if (!this.enabled) - return 0; return (this._server.server.address() as net.AddressInfo).port; } - public browserLaunchOptions(): LaunchOptions | undefined { - if (!this.enabled) - return; - return { - proxy: { - server: `socks5://127.0.0.1:${this._port()}` - } - }; + public proxyServer() { + return `socks5://127.0.0.1:${this._port()}`; } private _handler(info: SocksConnectionInfo, forward: () => void, intercept: (parent: SdkObject) => SocksInterceptedSocketHandler): void { @@ -67,16 +59,12 @@ export class BrowserServerPortForwardingServer extends EventEmitter { this.emit('incomingSocksSocket', socket); } - public enablePortForwarding(ports: number[]): void { - if (!this.enabled) - throw new Error(`Port forwarding needs to be enabled when launching the server via BrowserType.launchServer.`); + public setForwardedPorts(ports: number[]): void { debugLogger.log('proxy', `enable port forwarding on ports: ${ports}`); this._forwardPorts = ports; } public stop(): void { - if (!this.enabled) - return; debugLogger.log('proxy', 'stopping server'); this._server.close(); }