import type { IncomingMessage } from "node:http";
import os from "node:os";
import type { WebSocket } from "ws";
import { upsertPresence } from "../../../infra/system-presence.js";
import { rawDataToString } from "../../../infra/ws.js";
import type { createSubsystemLogger } from "../../../logging.js";
import { isGatewayCliClient, isWebchatClient } from "../../../utils/message-channel.js";
import type { ResolvedGatewayAuth } from "../../auth.js";
import { authorizeGatewayConnect } from "../../auth.js";
import { isLoopbackAddress } from "../../net.js";
import {
  type ConnectParams,
  ErrorCodes,
  type ErrorShape,
  errorShape,
  formatValidationErrors,
  PROTOCOL_VERSION,
  type RequestFrame,
  validateConnectParams,
  validateRequestFrame,
} from "../../protocol/index.js";
import { MAX_BUFFERED_BYTES, MAX_PAYLOAD_BYTES, TICK_INTERVAL_MS } from "../../server-constants.js";
import type { GatewayRequestContext, GatewayRequestHandlers } from "../../server-methods/types.js";
import { handleGatewayRequest } from "../../server-methods.js";
import { formatError } from "../../server-utils.js";
import { formatForLog, logWs } from "../../ws-log.js";
import { truncateCloseReason } from "../close-reason.js";
import {
  buildGatewaySnapshot,
  getHealthCache,
  getHealthVersion,
  incrementPresenceVersion,
  refreshGatewayHealthSnapshot,
} from "../health-state.js";
import type { GatewayWsClient } from "../ws-types.js";

type SubsystemLogger = ReturnType<typeof createSubsystemLogger>;

/**
 * Reads `value[key]` from an arbitrary parsed JSON value.
 *
 * @returns the property when `value` is a non-null object and the property is a
 *   string; `undefined` otherwise (including for primitives and `null`).
 */
function readStringField(value: unknown, key: string): string | undefined {
  if (!value || typeof value !== "object") return undefined;
  const field = (value as Record<string, unknown>)[key];
  return typeof field === "string" ? field : undefined;
}

/**
 * Installs the `"message"` listener for one gateway WebSocket connection.
 *
 * Before a client is registered (`getClient()` returns null), the first frame must be a
 * request frame `{ type: "req", method: "connect", params: ConnectParams }`. It is validated,
 * its protocol range is checked against `PROTOCOL_VERSION`, and it is authorized via
 * `authorizeGatewayConnect`. On success the client is registered with `setClient`, presence
 * is upserted for non-CLI clients, and a `hello-ok` payload is sent back. Any handshake
 * failure marks the handshake `"failed"`, records a close cause and closes the socket.
 *
 * Once connected, every frame must be a valid request frame. Each one is dispatched to
 * `handleGatewayRequest` without awaiting it. Handler failures are logged and answered
 * with an `UNAVAILABLE` error response.
 *
 * JSON parse errors (and anything else thrown synchronously in the listener) are logged.
 * If the handshake has not completed yet, the connection is closed.
 *
 * @param params Connection-scoped state, transport callbacks and loggers supplied by the
 *   WebSocket server.
 */
export function attachGatewayWsMessageHandler(params: {
  socket: WebSocket;
  upgradeReq: IncomingMessage;
  connId: string;
  remoteAddr?: string;
  forwardedFor?: string;
  requestHost?: string;
  requestOrigin?: string;
  requestUserAgent?: string;
  canvasHostUrl?: string;
  resolvedAuth: ResolvedGatewayAuth;
  gatewayMethods: string[];
  events: string[];
  extraHandlers: GatewayRequestHandlers;
  buildRequestContext: () => GatewayRequestContext;
  send: (obj: unknown) => void;
  close: (code?: number, reason?: string) => void;
  isClosed: () => boolean;
  clearHandshakeTimer: () => void;
  getClient: () => GatewayWsClient | null;
  setClient: (next: GatewayWsClient) => void;
  setHandshakeState: (state: "pending" | "connected" | "failed") => void;
  setCloseCause: (cause: string, meta?: Record<string, unknown>) => void;
  setLastFrameMeta: (meta: { type?: string; method?: string; id?: string }) => void;
  logGateway: SubsystemLogger;
  logHealth: SubsystemLogger;
  logWsControl: SubsystemLogger;
}) {
  const {
    socket,
    upgradeReq,
    connId,
    remoteAddr,
    forwardedFor,
    requestHost,
    requestOrigin,
    requestUserAgent,
    canvasHostUrl,
    resolvedAuth,
    gatewayMethods,
    events,
    extraHandlers,
    buildRequestContext,
    send,
    close,
    isClosed,
    clearHandshakeTimer,
    getClient,
    setClient,
    setHandshakeState,
    setCloseCause,
    setLastFrameMeta,
    logGateway,
    logHealth,
    logWsControl,
  } = params;

  const isWebchatConnect = (p: ConnectParams | null | undefined) => isWebchatClient(p?.client);

  socket.on("message", async (data) => {
    if (isClosed()) return;
    const text = rawDataToString(data);
    try {
      const parsed: unknown = JSON.parse(text);

      // Remember whichever identifying fields this frame carries, even if it fails validation below.
      const frameType = readStringField(parsed, "type");
      const frameMethod = readStringField(parsed, "method");
      const frameId = readStringField(parsed, "id");
      if (frameType || frameMethod || frameId) {
        setLastFrameMeta({ type: frameType, method: frameMethod, id: frameId });
      }

      const client = getClient();
      if (!client) {
        // Handshake must be a normal request:
        // { type:"req", method:"connect", params: ConnectParams }.
        const isRequestFrame = validateRequestFrame(parsed);
        if (
          !isRequestFrame ||
          (parsed as RequestFrame).method !== "connect" ||
          !validateConnectParams((parsed as RequestFrame).params)
        ) {
          const handshakeError = isRequestFrame
            ? (parsed as RequestFrame).method === "connect"
              ? `invalid connect params: ${formatValidationErrors(validateConnectParams.errors)}`
              : "invalid handshake: first request must be connect"
            : "invalid request frame";
          setHandshakeState("failed");
          setCloseCause("invalid-handshake", {
            frameType,
            frameMethod,
            frameId,
            handshakeError,
          });
          if (isRequestFrame) {
            const req = parsed as RequestFrame;
            send({
              type: "res",
              id: req.id,
              ok: false,
              error: errorShape(ErrorCodes.INVALID_REQUEST, handshakeError),
            });
          } else {
            logWsControl.warn(
              `invalid handshake conn=${connId} remote=${remoteAddr ?? "?"} fwd=${forwardedFor ?? "n/a"} origin=${requestOrigin ?? "n/a"} host=${requestHost ?? "n/a"} ua=${requestUserAgent ?? "n/a"}`,
            );
          }
          const closeReason = truncateCloseReason(handshakeError || "invalid handshake");
          if (isRequestFrame) {
            // Close is deferred only when an error response was sent above — presumably so
            // that response is flushed before the close frame.
            queueMicrotask(() => close(1008, closeReason));
          } else {
            close(1008, closeReason);
          }
          return;
        }

        const frame = parsed as RequestFrame;
        const connectParams = frame.params as ConnectParams;
        const clientLabel = connectParams.client.displayName ?? connectParams.client.id;

        // protocol negotiation: the client's [minProtocol, maxProtocol] must include ours.
        const { minProtocol, maxProtocol } = connectParams;
        if (maxProtocol < PROTOCOL_VERSION || minProtocol > PROTOCOL_VERSION) {
          setHandshakeState("failed");
          logWsControl.warn(
            `protocol mismatch conn=${connId} remote=${remoteAddr ?? "?"} client=${clientLabel} ${connectParams.client.mode} v${connectParams.client.version}`,
          );
          setCloseCause("protocol-mismatch", {
            minProtocol,
            maxProtocol,
            expectedProtocol: PROTOCOL_VERSION,
            client: connectParams.client.id,
            clientDisplayName: connectParams.client.displayName,
            mode: connectParams.client.mode,
            version: connectParams.client.version,
          });
          send({
            type: "res",
            id: frame.id,
            ok: false,
            error: errorShape(ErrorCodes.INVALID_REQUEST, "protocol mismatch", {
              details: { expectedProtocol: PROTOCOL_VERSION },
            }),
          });
          close(1002, "protocol mismatch");
          return;
        }

        const authResult = await authorizeGatewayConnect({
          auth: resolvedAuth,
          connectAuth: connectParams.auth,
          req: upgradeReq,
        });
        if (!authResult.ok) {
          setHandshakeState("failed");
          logWsControl.warn(
            `unauthorized conn=${connId} remote=${remoteAddr ?? "?"} client=${clientLabel} ${connectParams.client.mode} v${connectParams.client.version}`,
          );
          const authProvided = connectParams.auth?.token
            ? "token"
            : connectParams.auth?.password
              ? "password"
              : "none";
          setCloseCause("unauthorized", {
            authMode: resolvedAuth.mode,
            authProvided,
            authReason: authResult.reason,
            allowTailscale: resolvedAuth.allowTailscale,
            client: connectParams.client.id,
            clientDisplayName: connectParams.client.displayName,
            mode: connectParams.client.mode,
            version: connectParams.client.version,
          });
          send({
            type: "res",
            id: frame.id,
            ok: false,
            error: errorShape(ErrorCodes.INVALID_REQUEST, "unauthorized"),
          });
          close(1008, "unauthorized");
          return;
        }

        // The connection may have been closed while authorization was awaited. In that case,
        // don't register a client, upsert presence or bump the presence version for a socket
        // that is already gone.
        if (isClosed()) return;

        const authMethod = authResult.method ?? "none";
        const shouldTrackPresence = !isGatewayCliClient(connectParams.client);
        const clientId = connectParams.client.id;
        const instanceId = connectParams.client.instanceId;
        const presenceKey = shouldTrackPresence ? (instanceId ?? connId) : undefined;

        logWs("in", "connect", {
          connId,
          client: connectParams.client.id,
          clientDisplayName: connectParams.client.displayName,
          version: connectParams.client.version,
          mode: connectParams.client.mode,
          clientId,
          platform: connectParams.client.platform,
          auth: authMethod,
        });

        if (isWebchatConnect(connectParams)) {
          logWsControl.info(
            `webchat connected conn=${connId} remote=${remoteAddr ?? "?"} client=${clientLabel} ${connectParams.client.mode} v${connectParams.client.version}`,
          );
        }

        if (presenceKey) {
          upsertPresence(presenceKey, {
            host: connectParams.client.displayName ?? connectParams.client.id ?? os.hostname(),
            // Loopback addresses are not recorded as the presence IP.
            ip: isLoopbackAddress(remoteAddr) ? undefined : remoteAddr,
            version: connectParams.client.version,
            platform: connectParams.client.platform,
            deviceFamily: connectParams.client.deviceFamily,
            modelIdentifier: connectParams.client.modelIdentifier,
            mode: connectParams.client.mode,
            instanceId,
            reason: "connect",
          });
          incrementPresenceVersion();
        }

        const snapshot = buildGatewaySnapshot();
        const cachedHealth = getHealthCache();
        if (cachedHealth) {
          snapshot.health = cachedHealth;
          snapshot.stateVersion.health = getHealthVersion();
        }

        const helloOk = {
          type: "hello-ok",
          protocol: PROTOCOL_VERSION,
          server: {
            version: process.env.CLAWDBOT_VERSION ?? process.env.npm_package_version ?? "dev",
            commit: process.env.GIT_COMMIT,
            host: os.hostname(),
            connId,
          },
          features: { methods: gatewayMethods, events },
          snapshot,
          canvasHostUrl,
          policy: {
            maxPayload: MAX_PAYLOAD_BYTES,
            maxBufferedBytes: MAX_BUFFERED_BYTES,
            tickIntervalMs: TICK_INTERVAL_MS,
          },
        };

        clearHandshakeTimer();
        const nextClient: GatewayWsClient = {
          socket,
          connect: connectParams,
          connId,
          presenceKey,
        };
        setClient(nextClient);
        setHandshakeState("connected");

        logWs("out", "hello-ok", {
          connId,
          methods: gatewayMethods.length,
          events: events.length,
          presence: snapshot.presence.length,
          stateVersion: snapshot.stateVersion.presence,
        });

        send({ type: "res", id: frame.id, ok: true, payload: helloOk });

        // Fire-and-forget health refresh; failures are only logged.
        void refreshGatewayHealthSnapshot({ probe: true }).catch((err) =>
          logHealth.error(`post-connect health refresh failed: ${formatError(err)}`),
        );
        return;
      }

      // After handshake, accept only req frames
      if (!validateRequestFrame(parsed)) {
        send({
          type: "res",
          id: (parsed as { id?: unknown })?.id ?? "invalid",
          ok: false,
          error: errorShape(
            ErrorCodes.INVALID_REQUEST,
            `invalid request frame: ${formatValidationErrors(validateRequestFrame.errors)}`,
          ),
        });
        return;
      }

      const req = parsed as RequestFrame;
      logWs("in", "req", { connId, id: req.id, method: req.method });

      // Sends a response frame for this request and logs it (plus any extra meta fields).
      const respond = (
        ok: boolean,
        payload?: unknown,
        error?: ErrorShape,
        meta?: Record<string, unknown>,
      ) => {
        send({ type: "res", id: req.id, ok, payload, error });
        logWs("out", "res", {
          connId,
          id: req.id,
          ok,
          method: req.method,
          errorCode: error?.code,
          errorMessage: error?.message,
          ...meta,
        });
      };

      // Dispatched without awaiting so this listener returns immediately; failures are
      // logged and reported back to the client as UNAVAILABLE.
      void (async () => {
        await handleGatewayRequest({
          req,
          respond,
          client,
          isWebchatConnect,
          extraHandlers,
          context: buildRequestContext(),
        });
      })().catch((err) => {
        logGateway.error(`request handler failed: ${formatForLog(err)}`);
        respond(false, undefined, errorShape(ErrorCodes.UNAVAILABLE, formatForLog(err)));
      });
    } catch (err) {
      logGateway.error(`parse/handle error: ${String(err)}`);
      logWs("out", "parse-error", { connId, error: formatForLog(err) });
      // An unparseable frame before the handshake completes terminates the connection.
      if (!getClient()) {
        close();
      }
    }
  });
}