Browse Source

Pull fix of recording from websockify.

Pull websockify 7f487fdbd.

The reocrd parameter will turn on recording of all messages sent
to and from the client. The record parameter is a file prefix. The
full file-name will be the prefix with an extension '.HANDLER_ID'
based on the handler ID.
Joel Martin 14 years ago
parent
commit
8c305c60ad
2 changed files with 288 additions and 223 deletions
  1. 278 192
      utils/websocket.py
  2. 10 31
      utils/websockify

+ 278 - 192
utils/websocket.py

@@ -2,7 +2,7 @@
 
 '''
 Python WebSocket library with support for "wss://" encryption.
-Copyright 2010 Joel Martin
+Copyright 2011 Joel Martin
 Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
 
 Supports following protocol versions:
@@ -16,23 +16,48 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
 
 '''
 
-import sys, socket, ssl, struct, traceback, select
-import os, resource, errno, signal # daemonizing
-from SimpleHTTPServer import SimpleHTTPRequestHandler
-from cStringIO import StringIO
+import os, sys, time, errno, signal, socket, struct, traceback, select
+from cgi import parse_qsl
 from base64 import b64encode, b64decode
-try:
+
+# Imports that vary by python version
+if sys.hexversion > 0x3000000:
+    # python >= 3.0
+    from io import StringIO
+    from http.server import SimpleHTTPRequestHandler
+    from urllib.parse import urlsplit
+    b2s = lambda buf: buf.decode('latin_1')
+    s2b = lambda s: s.encode('latin_1')
+else:
+    # python 2.X
+    from cStringIO import StringIO
+    from SimpleHTTPServer import SimpleHTTPRequestHandler
+    from urlparse import urlsplit
+    # No-ops
+    b2s = lambda buf: buf
+    s2b = lambda s: s
+
+if sys.hexversion >= 0x2060000:
+    # python >= 2.6
+    from multiprocessing import Process
     from hashlib import md5, sha1
-except:
-    # Support python 2.4
+else:
+    # python < 2.6
+    Process = None
     from md5 import md5
     from sha import sha as sha1
-try:
-    import numpy, ctypes
-except:
-    numpy = ctypes = None
-from urlparse import urlsplit
-from cgi import parse_qsl
+
+# Degraded functionality if these imports are missing
+for mod, sup in [('numpy', 'HyBi protocol'),
+        ('ctypes', 'HyBi protocol'), ('ssl', 'TLS/SSL/wss'),
+        ('resource', 'daemonizing')]:
+    try:
+        globals()[mod] = __import__(mod)
+    except ImportError:
+        globals()[mod] = None
+        print("WARNING: no '%s' module, %s support disabled" % (
+            mod, sup))
+
 
 class WebSocketServer(object):
     """
@@ -72,6 +97,7 @@ Sec-WebSocket-Accept: %s\r
         self.listen_port = listen_port
         self.ssl_only    = ssl_only
         self.daemon      = daemon
+        self.handler_id  = 1
 
         # Make paths settings absolute
         self.cert = os.path.abspath(cert)
@@ -86,22 +112,32 @@ Sec-WebSocket-Accept: %s\r
         if self.web:
             os.chdir(self.web)
 
-        self.handler_id  = 1
-
-        print "WebSocket server settings:"
-        print "  - Listen on %s:%s" % (
-                self.listen_host, self.listen_port)
-        print "  - Flash security policy server"
+        # Sanity checks
+        if ssl and self.ssl_only:
+            raise Exception("No 'ssl' module and SSL-only specified")
+        if self.daemon and not resource:
+            raise Exception("Module 'resource' required to daemonize")
+
+        # Show configuration
+        print("WebSocket server settings:")
+        print("  - Listen on %s:%s" % (
+                self.listen_host, self.listen_port))
+        print("  - Flash security policy server")
         if self.web:
-            print "  - Web server"
-        if os.path.exists(self.cert):
-            print "  - SSL/TLS support"
-            if self.ssl_only:
-                print "  - Deny non-SSL/TLS connections"
+            print("  - Web server")
+        if ssl:
+            if os.path.exists(self.cert):
+                print("  - SSL/TLS support")
+                if self.ssl_only:
+                    print("  - Deny non-SSL/TLS connections")
+            else:
+                print("  - No SSL/TLS support (no cert file)")
         else:
-            print "  - No SSL/TLS support (no cert file)"
+            print("  - No SSL/TLS support (no 'ssl' module)")
         if self.daemon:
-            print "  - Backgrounding (daemon)"
+            print("  - Backgrounding (daemon)")
+        if self.record:
+            print("  - Recording to '%s.*'" % self.record)
 
     #
     # WebSocketServer static methods
@@ -133,7 +169,8 @@ Sec-WebSocket-Accept: %s\r
             try:
                 if fd != keepfd:
                     os.close(fd)
-            except OSError, exc:
+            except OSError:
+                _, exc, _ = sys.exc_info()
                 if exc.errno != errno.EBADF: raise
 
         # Redirect I/O to /dev/null
@@ -164,9 +201,9 @@ Sec-WebSocket-Accept: %s\r
         elif payload_len >= 65536:
             header = struct.pack('>BBQ', b1, 127, payload_len)
 
-        #print "Encoded: %s" % repr(header + buf)
+        #print("Encoded: %s" % repr(header + buf))
 
-        return header + buf
+        return header + buf, len(header), 0
 
     @staticmethod
     def decode_hybi(buf, base64=False):
@@ -175,6 +212,7 @@ Sec-WebSocket-Accept: %s\r
             {'fin'          : 0_or_1,
              'opcode'       : number,
              'mask'         : 32_bit_number,
+             'hlen'         : header_bytes_number,
              'length'       : payload_bytes_number,
              'payload'      : decoded_buffer,
              'left'         : bytes_left_number,
@@ -182,122 +220,103 @@ Sec-WebSocket-Accept: %s\r
              'close_reason' : string}
         """
 
-        ret = {'fin'          : 0,
-               'opcode'       : 0,
-               'mask'         : 0,
-               'length'       : 0,
-               'payload'      : None,
-               'left'         : 0,
-               'close_code'   : None,
-               'close_reason' : None}
+        f = {'fin'          : 0,
+             'opcode'       : 0,
+             'mask'         : 0,
+             'hlen'         : 2,
+             'length'       : 0,
+             'payload'      : None,
+             'left'         : 0,
+             'close_code'   : None,
+             'close_reason' : None}
 
         blen = len(buf)
-        ret['left'] = blen
-        header_len = 2
+        f['left'] = blen
 
-        if blen < header_len:
-            return ret # Incomplete frame header
+        if blen < f['hlen']:
+            return f # Incomplete frame header
 
         b1, b2 = struct.unpack_from(">BB", buf)
-        ret['opcode'] = b1 & 0x0f
-        ret['fin'] = (b1 & 0x80) >> 7
+        f['opcode'] = b1 & 0x0f
+        f['fin'] = (b1 & 0x80) >> 7
         has_mask = (b2 & 0x80) >> 7
 
-        ret['length'] = b2 & 0x7f
+        f['length'] = b2 & 0x7f
 
-        if ret['length'] == 126:
-            header_len = 4
-            if blen < header_len:
-                return ret # Incomplete frame header
-            (ret['length'],) = struct.unpack_from('>xxH', buf)
-        elif ret['length'] == 127:
-            header_len = 10
-            if blen < header_len:
-                return ret # Incomplete frame header
-            (ret['length'],) = struct.unpack_from('>xxQ', buf)
+        if f['length'] == 126:
+            f['hlen'] = 4
+            if blen < f['hlen']:
+                return f # Incomplete frame header
+            (f['length'],) = struct.unpack_from('>xxH', buf)
+        elif f['length'] == 127:
+            f['hlen'] = 10
+            if blen < f['hlen']:
+                return f # Incomplete frame header
+            (f['length'],) = struct.unpack_from('>xxQ', buf)
 
-        full_len = header_len + has_mask * 4 + ret['length']
+        full_len = f['hlen'] + has_mask * 4 + f['length']
 
         if blen < full_len: # Incomplete frame
-            return ret # Incomplete frame header
+            return f # Incomplete frame header
 
         # Number of bytes that are part of the next frame(s)
-        ret['left'] = blen - full_len
+        f['left'] = blen - full_len
 
         # Process 1 frame
         if has_mask:
             # unmask payload
-            ret['mask'] = buf[header_len:header_len+4]
+            f['mask'] = buf[f['hlen']:f['hlen']+4]
             b = c = ''
-            if ret['length'] >= 4:
+            if f['length'] >= 4:
                 mask = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'),
-                        offset=header_len, count=1)
+                        offset=f['hlen'], count=1)
                 data = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'),
-                        offset=header_len + 4, count=int(ret['length'] / 4))
+                        offset=f['hlen'] + 4, count=int(f['length'] / 4))
                 #b = numpy.bitwise_xor(data, mask).data
                 b = numpy.bitwise_xor(data, mask).tostring()
 
-            if ret['length'] % 4:
-                print "Partial unmask"
+            if f['length'] % 4:
+                print("Partial unmask")
                 mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
-                        offset=header_len, count=(ret['length'] % 4))
+                        offset=f['hlen'], count=(f['length'] % 4))
                 data = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
-                        offset=full_len - (ret['length'] % 4),
-                        count=(ret['length'] % 4))
+                        offset=full_len - (f['length'] % 4),
+                        count=(f['length'] % 4))
                 c = numpy.bitwise_xor(data, mask).tostring()
-            ret['payload'] = b + c
+            f['payload'] = b + c
         else:
-            print "Unmasked frame:", repr(buf)
-            ret['payload'] = buf[(header_len + has_mask * 4):full_len]
+            print("Unmasked frame: %s" % repr(buf))
+            f['payload'] = buf[(f['hlen'] + has_mask * 4):full_len]
 
-        if base64 and ret['opcode'] in [1, 2]:
+        if base64 and f['opcode'] in [1, 2]:
             try:
-                ret['payload'] = b64decode(ret['payload'])
+                f['payload'] = b64decode(f['payload'])
             except:
-                print "Exception while b64decoding buffer:", repr(buf)
+                print("Exception while b64decoding buffer: %s" %
+                        repr(buf))
                 raise
 
-        if ret['opcode'] == 0x08:
-            if ret['length'] >= 2:
-                ret['close_code'] = struct.unpack_from(
-                        ">H", ret['payload'])
-            if ret['length'] > 3:
-                ret['close_reason'] = ret['payload'][2:]
+        if f['opcode'] == 0x08:
+            if f['length'] >= 2:
+                f['close_code'] = struct.unpack_from(">H", f['payload'])
+            if f['length'] > 3:
+                f['close_reason'] = f['payload'][2:]
 
-        return ret
+        return f
 
     @staticmethod
     def encode_hixie(buf):
-        return "\x00" + b64encode(buf) + "\xff"
+        return s2b("\x00" + b2s(b64encode(buf)) + "\xff"), 1, 1
 
     @staticmethod
     def decode_hixie(buf):
-        end = buf.find('\xff')
+        end = buf.find(s2b('\xff'))
         return {'payload': b64decode(buf[1:end]),
+                'hlen': 1,
+                'length': end - 1,
                 'left': len(buf) - (end + 1)}
 
 
-    @staticmethod
-    def parse_handshake(handshake):
-        """ Parse fields from client WebSockets handshake. """
-        ret = {}
-        req_lines = handshake.split("\r\n")
-        if not req_lines[0].startswith("GET "):
-            raise Exception("Invalid handshake: no GET request line")
-        ret['path'] = req_lines[0].split(" ")[1]
-        for line in req_lines[1:]:
-            if line == "": break
-            try:
-                var, val = line.split(": ")
-            except:
-                raise Exception("Invalid handshake header: %s" % line)
-            ret[var] = val
-
-        if req_lines[-2] == "":
-            ret['key3'] = req_lines[-1]
-
-        return ret
-
     @staticmethod
     def gen_md5(keys):
         """ Generate hash value for WebSockets hixie-76. """
@@ -309,7 +328,8 @@ Sec-WebSocket-Accept: %s\r
         num1 = int("".join([c for c in key1 if c.isdigit()])) / spaces1
         num2 = int("".join([c for c in key2 if c.isdigit()])) / spaces2
 
-        return md5(struct.pack('>II8s', num1, num2, key3)).digest()
+        return b2s(md5(struct.pack('>II8s',
+            int(num1), int(num2), key3)).digest())
 
     #
     # WebSocketServer logging/output functions
@@ -324,7 +344,7 @@ Sec-WebSocket-Accept: %s\r
     def msg(self, msg):
         """ Output message with handler_id prefix. """
         if not self.daemon:
-            print "% 3d: %s" % (self.handler_id, msg)
+            print("% 3d: %s" % (self.handler_id, msg))
 
     def vmsg(self, msg):
         """ Same as msg() but only if verbose. """
@@ -342,17 +362,27 @@ Sec-WebSocket-Accept: %s\r
         than 0, then the caller should call again when the socket is
         ready. """
 
+        tdelta = int(time.time()*1000) - self.start_time
+
         if bufs:
             for buf in bufs:
                 if self.version.startswith("hybi"):
                     if self.base64:
-                        self.send_parts.append(self.encode_hybi(buf,
-                            opcode=1, base64=True))
+                        encbuf, lenhead, lentail = self.encode_hybi(
+                                buf, opcode=1, base64=True)
                     else:
-                        self.send_parts.append(self.encode_hybi(buf,
-                            opcode=2, base64=False))
+                        encbuf, lenhead, lentail = self.encode_hybi(
+                                buf, opcode=2, base64=False)
+
                 else:
-                    self.send_parts.append(self.encode_hixie(buf))
+                    encbuf, lenhead, lentail = self.encode_hixie(buf)
+
+                if self.rec:
+                    self.rec.write("%s,\n" %
+                            repr("{%s{" % tdelta
+                                + encbuf[lenhead:-lentail]))
+
+                self.send_parts.append(encbuf)
 
         while self.send_parts:
             # Send pending frames
@@ -377,6 +407,7 @@ Sec-WebSocket-Accept: %s\r
 
         closed = False
         bufs = []
+        tdelta = int(time.time()*1000) - self.start_time
 
         buf = self.client.recv(self.buffer_size)
         if len(buf) == 0:
@@ -392,7 +423,7 @@ Sec-WebSocket-Accept: %s\r
             if self.version.startswith("hybi"):
 
                 frame = self.decode_hybi(buf, base64=self.base64)
-                #print "Received buf: %s, frame: %s" % (repr(buf), frame)
+                #print("Received buf: %s, frame: %s" % (repr(buf), frame))
 
                 if frame['payload'] == None:
                     # Incomplete/partial frame
@@ -416,7 +447,7 @@ Sec-WebSocket-Accept: %s\r
                     buf = buf[2:]
                     continue # No-op
 
-                elif buf.count('\xff') == 0:
+                elif buf.count(s2b('\xff')) == 0:
                     # Partial frame
                     self.traffic("}.")
                     self.recv_part = buf
@@ -426,6 +457,13 @@ Sec-WebSocket-Accept: %s\r
 
             self.traffic("}")
 
+            if self.rec:
+                start = frame['hlen']
+                end = frame['hlen'] + frame['length']
+                self.rec.write("%s,\n" %
+                        repr("}%s}" % tdelta + buf[start:end]))
+
+
             bufs.append(frame['payload'])
 
             if frame['left']:
@@ -439,7 +477,7 @@ Sec-WebSocket-Accept: %s\r
         """ Send a WebSocket orderly close frame. """
 
         if self.version.startswith("hybi"):
-            msg = ''
+            msg = s2b('')
             if code != None:
                 msg = struct.pack(">H%ds" % (len(reason)), code)
 
@@ -447,7 +485,7 @@ Sec-WebSocket-Accept: %s\r
             self.client.send(buf)
 
         elif self.version == "hixie-76":
-            buf = self.encode_hixie('\xff\x00')
+            buf = s2b('\xff\x00')
             self.client.send(buf)
 
         # No orderly close for 75
@@ -483,14 +521,16 @@ Sec-WebSocket-Accept: %s\r
         if handshake == "":
             raise self.EClose("ignoring empty handshake")
 
-        elif handshake.startswith("<policy-file-request/>"):
+        elif handshake.startswith(s2b("<policy-file-request/>")):
             # Answer Flash policy request
             handshake = sock.recv(1024)
-            sock.send(self.policy_response)
+            sock.send(s2b(self.policy_response))
             raise self.EClose("Sending flash policy response")
 
         elif handshake[0] in ("\x16", "\x80"):
             # SSL wrap the connection
+            if not ssl:
+                raise self.EClose("SSL connection but no 'ssl' module")
             if not os.path.exists(self.cert):
                 raise self.EClose("SSL connection but '%s' not found"
                                   % self.cert)
@@ -500,7 +540,8 @@ Sec-WebSocket-Accept: %s\r
                         server_side=True,
                         certfile=self.cert,
                         keyfile=self.key)
-            except ssl.SSLError, x:
+            except ssl.SSLError:
+                _, x, _ = sys.exc_info()
                 if x.args[0] == ssl.SSL_ERROR_EOF:
                     raise self.EClose("")
                 else:
@@ -517,29 +558,21 @@ Sec-WebSocket-Accept: %s\r
             scheme = "ws"
             stype = "Plain non-SSL (ws://)"
 
-        # Now get the data from the socket
-        handshake = retsock.recv(4096)
-
-        if len(handshake) == 0:
-            raise self.EClose("Client closed during handshake")
-
-        # Check for and handle normal web requests
-        if (handshake.startswith('GET ') and
-                handshake.find('Upgrade: WebSocket\r\n') == -1 and
-                handshake.find('Upgrade: websocket\r\n') == -1):
-            if not self.web:
-                raise self.EClose("Normal web request received but disallowed")
-            sh = SplitHTTPHandler(handshake, retsock, address)
-            if sh.last_code < 200 or sh.last_code >= 300:
-                raise self.EClose(sh.last_message)
-            elif self.verbose:
-                raise self.EClose(sh.last_message)
-            else:
-                raise self.EClose("")
+        wsh = WSRequestHandler(retsock, address, not self.web)
+        if wsh.last_code == 101:
+            # Continue on to handle WebSocket upgrade
+            pass
+        elif wsh.last_code == 405:
+            raise self.EClose("Normal web request received but disallowed")
+        elif wsh.last_code < 200 or wsh.last_code >= 300:
+            raise self.EClose(wsh.last_message)
+        elif self.verbose:
+            raise self.EClose(wsh.last_message)
+        else:
+            raise self.EClose("")
 
-        #self.msg("handshake: " + repr(handshake))
-        # Parse client WebSockets handshake
-        h = self.headers = self.parse_handshake(handshake)
+        h = self.headers = wsh.headers
+        path = self.path = wsh.path
 
         prot = 'WebSocket-Protocol'
         protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',')
@@ -548,8 +581,8 @@ Sec-WebSocket-Accept: %s\r
         if ver:
             # HyBi/IETF version of the protocol
 
-            if not numpy or not ctypes:
-                self.EClose("Python numpy and ctypes modules required for HyBi-07 or greater")
+            if sys.hexversion < 0x2060000 or not numpy:
+                raise self.EClose("Python >= 2.6 and numpy module is required for HyBi-07 or greater")
 
             if ver == '7':
                 self.version = "hybi-07"
@@ -567,7 +600,7 @@ Sec-WebSocket-Accept: %s\r
                 raise self.EClose("Client must support 'binary' or 'base64' protocol")
 
             # Generate the hash value for the accept header
-            accept = b64encode(sha1(key + self.GUID).digest())
+            accept = b64encode(sha1(s2b(key + self.GUID)).digest())
 
             response = self.server_handshake_hybi % accept
             if self.base64:
@@ -592,7 +625,7 @@ Sec-WebSocket-Accept: %s\r
             self.base64 = True
 
             response = self.server_handshake_hixie % (pre,
-                    h['Origin'], pre, scheme, h['Host'], h['path'])
+                    h['Origin'], pre, scheme, h['Host'], path)
 
             if 'base64' in protocols:
                 response += "%sWebSocket-Protocol: base64\r\n" % pre
@@ -606,7 +639,7 @@ Sec-WebSocket-Accept: %s\r
 
         # Send server WebSockets handshake response
         #self.msg("sending response [%s]" % response)
-        retsock.send(response)
+        retsock.send(s2b(response))
 
         # Return the WebSockets socket which may be SSL wrapped
         return retsock
@@ -624,9 +657,8 @@ Sec-WebSocket-Accept: %s\r
         #self.vmsg("Running poll()")
         pass
 
-    def top_SIGCHLD(self, sig, stack):
-        # Reap zombies after calling child SIGCHLD handler
-        self.do_SIGCHLD(sig, stack)
+    def fallback_SIGCHLD(self, sig, stack):
+        # Reap zombies when using os.fork() (python 2.4)
         self.vmsg("Got SIGCHLD, reaping zombies")
         try:
             result = os.waitpid(-1, os.WNOHANG)
@@ -636,14 +668,52 @@ Sec-WebSocket-Accept: %s\r
         except (OSError):
             pass
 
-    def do_SIGCHLD(self, sig, stack):
-        pass
-
     def do_SIGINT(self, sig, stack):
         self.msg("Got SIGINT, exiting")
         sys.exit(0)
 
-    def new_client(self, client):
+    def top_new_client(self, startsock, address):
+        """ Do something with a WebSockets client connection. """
+        # Initialize per client settings
+        self.send_parts = []
+        self.recv_part  = None
+        self.base64     = False
+        self.rec        = None
+        self.start_time = int(time.time()*1000)
+
+        # handler process
+        try:
+            try:
+                self.client = self.do_handshake(startsock, address)
+
+                if self.record:
+                    # Record raw frame data as JavaScript array
+                    fname = "%s.%s" % (self.record,
+                                        self.handler_id)
+                    self.msg("opening record file: %s" % fname)
+                    self.rec = open(fname, 'w+')
+                    self.rec.write("var VNC_frame_data = [\n")
+
+                self.new_client()
+            except self.EClose:
+                _, exc, _ = sys.exc_info()
+                # Connection was not a WebSockets connection
+                if exc.args[0]:
+                    self.msg("%s: %s" % (address[0], exc.args[0]))
+            except Exception:
+                _, exc, _ = sys.exc_info()
+                self.msg("handler exception: %s" % str(exc))
+                if self.verbose:
+                    self.msg(traceback.format_exc())
+        finally:
+            if self.rec:
+                self.rec.write("'EOF']\n")
+                self.rec.close()
+
+            if self.client and self.client != startsock:
+                self.client.close()
+
+    def new_client(self):
         """ Do something with a WebSockets client connection. """
         raise("WebSocketServer.new_client() must be overloaded")
 
@@ -665,9 +735,11 @@ Sec-WebSocket-Accept: %s\r
 
         self.started()  # Some things need to happen after daemonizing
 
-        # Reep zombies
-        signal.signal(signal.SIGCHLD, self.top_SIGCHLD)
+        # Allow override of SIGINT
         signal.signal(signal.SIGINT, self.do_SIGINT)
+        if not Process:
+            # os.fork() (python 2.4) child reaper
+            signal.signal(signal.SIGCHLD, self.fallback_SIGCHLD)
 
         while True:
             try:
@@ -679,14 +751,17 @@ Sec-WebSocket-Accept: %s\r
                     try:
                         self.poll()
 
-                        ready = select.select([lsock], [], [], 1)[0];
+                        ready = select.select([lsock], [], [], 1)[0]
                         if lsock in ready:
                             startsock, address = lsock.accept()
                         else:
                             continue
-                    except Exception, exc:
+                    except Exception:
+                        _, exc, _ = sys.exc_info()
                         if hasattr(exc, 'errno'):
                             err = exc.errno
+                        elif hasattr(exc, 'args'):
+                            err = exc.args[0]
                         else:
                             err = exc[0]
                         if err == errno.EINTR:
@@ -695,55 +770,67 @@ Sec-WebSocket-Accept: %s\r
                         else:
                             raise
 
-                    self.vmsg('%s: forking handler' % address[0])
-                    pid = os.fork()
-
-                    if pid == 0:
-                        # Initialize per client settings
-                        self.send_parts = []
-                        self.recv_part  = None
-                        self.base64     = False
-                        # handler process
-                        self.client = self.do_handshake(
-                                startsock, address)
-                        self.new_client()
+                    if Process:
+                        self.vmsg('%s: new handler Process' % address[0])
+                        p = Process(target=self.top_new_client,
+                                args=(startsock, address))
+                        p.start()
+                        # child will not return
                     else:
-                        # parent process
-                        self.handler_id += 1
-
-                except self.EClose, exc:
-                    # Connection was not a WebSockets connection
-                    if exc.args[0]:
-                        self.msg("%s: %s" % (address[0], exc.args[0]))
-                except KeyboardInterrupt, exc:
+                        # python 2.4
+                        self.vmsg('%s: forking handler' % address[0])
+                        pid = os.fork()
+                        if pid == 0:
+                            # child handler process
+                            self.top_new_client(startsock, address)
+                            break  # child process exits
+
+                    # parent process
+                    self.handler_id += 1
+
+                except KeyboardInterrupt:
+                    _, exc, _ = sys.exc_info()
+                    print("In KeyboardInterrupt")
                     pass
-                except Exception, exc:
+                except SystemExit:
+                    _, exc, _ = sys.exc_info()
+                    print("In SystemExit")
+                    break
+                except Exception:
+                    _, exc, _ = sys.exc_info()
                     self.msg("handler exception: %s" % str(exc))
                     if self.verbose:
                         self.msg(traceback.format_exc())
 
             finally:
-                if self.client and self.client != startsock:
-                    self.client.close()
                 if startsock:
                     startsock.close()
 
-            if pid == 0:
-                break # Child process exits
-
 
-# HTTP handler with request from a string and response to a socket
-class SplitHTTPHandler(SimpleHTTPRequestHandler):
-    def __init__(self, req, resp, addr):
-        # Save the response socket
-        self.response = resp
+# HTTP handler with WebSocket upgrade support
+class WSRequestHandler(SimpleHTTPRequestHandler):
+    def __init__(self, req, addr, only_upgrade=False):
+        self.only_upgrade = only_upgrade # only allow upgrades
         SimpleHTTPRequestHandler.__init__(self, req, addr, object())
 
-    def setup(self):
-        self.connection = self.response
-        # Duck type request string to file object
-        self.rfile = StringIO(self.request)
-        self.wfile = self.connection.makefile('wb', self.wbufsize)
+    def do_GET(self):
+        if (self.headers.get('upgrade') and
+                self.headers.get('upgrade').lower() == 'websocket'):
+
+            if (self.headers.get('sec-websocket-key1') or
+                    self.headers.get('websocket-key1')):
+                # For Hixie-76 read out the key hash
+                self.headers.__setitem__('key3', self.rfile.read(8))
+
+            # Just indicate that an WebSocket upgrade is needed
+            self.last_code = 101
+            self.last_message = "101 Switching Protocols"
+        elif self.only_upgrade:
+            # Normal web request responses are disabled
+            self.last_code = 405
+            self.last_message = "405 Method Not Allowed"
+        else:
+            SimpleHTTPRequestHandler.do_GET(self)
 
     def send_response(self, code, message=None):
         # Save the status code
@@ -754,4 +841,3 @@ class SplitHTTPHandler(SimpleHTTPRequestHandler):
         # Save instead of printing
         self.last_message = f % args
 
-

+ 10 - 31
utils/websockify

@@ -2,7 +2,7 @@
 
 '''
 A WebSocket to TCP socket proxy with support for "wss://" encryption.
-Copyright 2010 Joel Martin
+Copyright 2011 Joel Martin
 Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
 
 You can make a cert/key with openssl using:
@@ -74,7 +74,7 @@ Traffic Legend:
         WebSocketServer.__init__(self, *args, **kwargs)
 
     def run_wrap_cmd(self):
-        print "Starting '%s'" % " ".join(self.wrap_cmd)
+        print("Starting '%s'" % " ".join(self.wrap_cmd))
         self.wrap_times.append(time.time())
         self.wrap_times.pop(0)
         self.cmd = subprocess.Popen(
@@ -88,14 +88,14 @@ Traffic Legend:
         # Need to call wrapped command after daemonization so we can
         # know when the wrapped command exits
         if self.wrap_cmd:
-            print "  - proxying from %s:%s to '%s' (port %s)\n" % (
+            print("  - proxying from %s:%s to '%s' (port %s)\n" % (
                     self.listen_host, self.listen_port,
-                    " ".join(self.wrap_cmd), self.target_port)
+                    " ".join(self.wrap_cmd), self.target_port))
             self.run_wrap_cmd()
         else:
-            print "  - proxying from %s:%s to %s:%s\n" % (
+            print("  - proxying from %s:%s to %s:%s\n" % (
                     self.listen_host, self.listen_port,
-                    self.target_host, self.target_port)
+                    self.target_host, self.target_port))
 
     def poll(self):
         # If we are wrapping a command, check it's status
@@ -118,7 +118,7 @@ Traffic Legend:
                 if (now - avg) < 10:
                     # 3 times in the last 10 seconds
                     if self.spawn_message:
-                        print "Command respawning too fast"
+                        print("Command respawning too fast")
                         self.spawn_message = False
                 else:
                     self.run_wrap_cmd()
@@ -138,15 +138,6 @@ Traffic Legend:
         Called after a new WebSocket connection has been established.
         """
 
-        self.rec = None
-        if self.record:
-            # Record raw frame data as a JavaScript compatible file
-            fname = "%s.%s" % (self.record,
-                                self.handler_id)
-            self.msg("opening record file: %s" % fname)
-            self.rec = open(fname, 'w+')
-            self.rec.write("var VNC_frame_data = [\n")
-
         # Connect to the target
         self.msg("connecting to: %s:%s" % (
                  self.target_host, self.target_port))
@@ -154,19 +145,17 @@ Traffic Legend:
         tsock.connect((self.target_host, self.target_port))
 
         if self.verbose and not self.daemon:
-            print self.traffic_legend
+            print(self.traffic_legend)
 
         # Start proxying
         try:
             self.do_proxy(tsock)
         except:
             if tsock:
+                tsock.shutdown(socket.SHUT_RDWR)
                 tsock.close()
                 self.vmsg("%s:%s: Target closed" %(
                     self.target_host, self.target_port))
-            if self.rec:
-                self.rec.write("'EOF']\n")
-                self.rec.close()
             raise
 
     def do_proxy(self, target):
@@ -177,11 +166,9 @@ Traffic Legend:
         c_pend = 0
         tqueue = []
         rlist = [self.client, target]
-        tstart = int(time.time()*1000)
 
         while True:
             wlist = []
-            tdelta = int(time.time()*1000) - tstart
 
             if tqueue: wlist.append(target)
             if cqueue or c_pend: wlist.append(self.client)
@@ -212,11 +199,8 @@ Traffic Legend:
             if self.client in outs:
                 # Send queued target data to the client
                 c_pend = self.send_frames(cqueue)
-                cqueue = []
 
-                #if self.rec:
-                #    self.rec.write("%s,\n" %
-                #            repr("{%s{" % tdelta + dat[1:-1]))
+                cqueue = []
 
 
             if self.client in ins:
@@ -224,11 +208,6 @@ Traffic Legend:
                 bufs, closed = self.recv_frames()
                 tqueue.extend(bufs)
 
-                #if self.rec:
-                #    for b in bufs:
-                #        self.rec.write(
-                #                repr("}%s}%s" % (tdelta, b)) + ",\n")
-
                 if closed:
                     # TODO: What about blocking on client socket?
                     self.send_close()