Explorar o código

Pull fix of recording from websockify.

Pull websockify 7f487fdbd.

The reocrd parameter will turn on recording of all messages sent
to and from the client. The record parameter is a file prefix. The
full file-name will be the prefix with an extension '.HANDLER_ID'
based on the handler ID.
Joel Martin %!s(int64=14) %!d(string=hai) anos
pai
achega
8c305c60ad
Modificáronse 2 ficheiros con 288 adicións e 223 borrados
  1. 278 192
      utils/websocket.py
  2. 10 31
      utils/websockify

+ 278 - 192
utils/websocket.py

@@ -2,7 +2,7 @@
 
 
 '''
 '''
 Python WebSocket library with support for "wss://" encryption.
 Python WebSocket library with support for "wss://" encryption.
-Copyright 2010 Joel Martin
+Copyright 2011 Joel Martin
 Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
 Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
 
 
 Supports following protocol versions:
 Supports following protocol versions:
@@ -16,23 +16,48 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
 
 
 '''
 '''
 
 
-import sys, socket, ssl, struct, traceback, select
-import os, resource, errno, signal # daemonizing
-from SimpleHTTPServer import SimpleHTTPRequestHandler
-from cStringIO import StringIO
+import os, sys, time, errno, signal, socket, struct, traceback, select
+from cgi import parse_qsl
 from base64 import b64encode, b64decode
 from base64 import b64encode, b64decode
-try:
+
+# Imports that vary by python version
+if sys.hexversion > 0x3000000:
+    # python >= 3.0
+    from io import StringIO
+    from http.server import SimpleHTTPRequestHandler
+    from urllib.parse import urlsplit
+    b2s = lambda buf: buf.decode('latin_1')
+    s2b = lambda s: s.encode('latin_1')
+else:
+    # python 2.X
+    from cStringIO import StringIO
+    from SimpleHTTPServer import SimpleHTTPRequestHandler
+    from urlparse import urlsplit
+    # No-ops
+    b2s = lambda buf: buf
+    s2b = lambda s: s
+
+if sys.hexversion >= 0x2060000:
+    # python >= 2.6
+    from multiprocessing import Process
     from hashlib import md5, sha1
     from hashlib import md5, sha1
-except:
-    # Support python 2.4
+else:
+    # python < 2.6
+    Process = None
     from md5 import md5
     from md5 import md5
     from sha import sha as sha1
     from sha import sha as sha1
-try:
-    import numpy, ctypes
-except:
-    numpy = ctypes = None
-from urlparse import urlsplit
-from cgi import parse_qsl
+
+# Degraded functionality if these imports are missing
+for mod, sup in [('numpy', 'HyBi protocol'),
+        ('ctypes', 'HyBi protocol'), ('ssl', 'TLS/SSL/wss'),
+        ('resource', 'daemonizing')]:
+    try:
+        globals()[mod] = __import__(mod)
+    except ImportError:
+        globals()[mod] = None
+        print("WARNING: no '%s' module, %s support disabled" % (
+            mod, sup))
+
 
 
 class WebSocketServer(object):
 class WebSocketServer(object):
     """
     """
@@ -72,6 +97,7 @@ Sec-WebSocket-Accept: %s\r
         self.listen_port = listen_port
         self.listen_port = listen_port
         self.ssl_only    = ssl_only
         self.ssl_only    = ssl_only
         self.daemon      = daemon
         self.daemon      = daemon
+        self.handler_id  = 1
 
 
         # Make paths settings absolute
         # Make paths settings absolute
         self.cert = os.path.abspath(cert)
         self.cert = os.path.abspath(cert)
@@ -86,22 +112,32 @@ Sec-WebSocket-Accept: %s\r
         if self.web:
         if self.web:
             os.chdir(self.web)
             os.chdir(self.web)
 
 
-        self.handler_id  = 1
-
-        print "WebSocket server settings:"
-        print "  - Listen on %s:%s" % (
-                self.listen_host, self.listen_port)
-        print "  - Flash security policy server"
+        # Sanity checks
+        if ssl and self.ssl_only:
+            raise Exception("No 'ssl' module and SSL-only specified")
+        if self.daemon and not resource:
+            raise Exception("Module 'resource' required to daemonize")
+
+        # Show configuration
+        print("WebSocket server settings:")
+        print("  - Listen on %s:%s" % (
+                self.listen_host, self.listen_port))
+        print("  - Flash security policy server")
         if self.web:
         if self.web:
-            print "  - Web server"
-        if os.path.exists(self.cert):
-            print "  - SSL/TLS support"
-            if self.ssl_only:
-                print "  - Deny non-SSL/TLS connections"
+            print("  - Web server")
+        if ssl:
+            if os.path.exists(self.cert):
+                print("  - SSL/TLS support")
+                if self.ssl_only:
+                    print("  - Deny non-SSL/TLS connections")
+            else:
+                print("  - No SSL/TLS support (no cert file)")
         else:
         else:
-            print "  - No SSL/TLS support (no cert file)"
+            print("  - No SSL/TLS support (no 'ssl' module)")
         if self.daemon:
         if self.daemon:
-            print "  - Backgrounding (daemon)"
+            print("  - Backgrounding (daemon)")
+        if self.record:
+            print("  - Recording to '%s.*'" % self.record)
 
 
     #
     #
     # WebSocketServer static methods
     # WebSocketServer static methods
@@ -133,7 +169,8 @@ Sec-WebSocket-Accept: %s\r
             try:
             try:
                 if fd != keepfd:
                 if fd != keepfd:
                     os.close(fd)
                     os.close(fd)
-            except OSError, exc:
+            except OSError:
+                _, exc, _ = sys.exc_info()
                 if exc.errno != errno.EBADF: raise
                 if exc.errno != errno.EBADF: raise
 
 
         # Redirect I/O to /dev/null
         # Redirect I/O to /dev/null
@@ -164,9 +201,9 @@ Sec-WebSocket-Accept: %s\r
         elif payload_len >= 65536:
         elif payload_len >= 65536:
             header = struct.pack('>BBQ', b1, 127, payload_len)
             header = struct.pack('>BBQ', b1, 127, payload_len)
 
 
-        #print "Encoded: %s" % repr(header + buf)
+        #print("Encoded: %s" % repr(header + buf))
 
 
-        return header + buf
+        return header + buf, len(header), 0
 
 
     @staticmethod
     @staticmethod
     def decode_hybi(buf, base64=False):
     def decode_hybi(buf, base64=False):
@@ -175,6 +212,7 @@ Sec-WebSocket-Accept: %s\r
             {'fin'          : 0_or_1,
             {'fin'          : 0_or_1,
              'opcode'       : number,
              'opcode'       : number,
              'mask'         : 32_bit_number,
              'mask'         : 32_bit_number,
+             'hlen'         : header_bytes_number,
              'length'       : payload_bytes_number,
              'length'       : payload_bytes_number,
              'payload'      : decoded_buffer,
              'payload'      : decoded_buffer,
              'left'         : bytes_left_number,
              'left'         : bytes_left_number,
@@ -182,122 +220,103 @@ Sec-WebSocket-Accept: %s\r
              'close_reason' : string}
              'close_reason' : string}
         """
         """
 
 
-        ret = {'fin'          : 0,
-               'opcode'       : 0,
-               'mask'         : 0,
-               'length'       : 0,
-               'payload'      : None,
-               'left'         : 0,
-               'close_code'   : None,
-               'close_reason' : None}
+        f = {'fin'          : 0,
+             'opcode'       : 0,
+             'mask'         : 0,
+             'hlen'         : 2,
+             'length'       : 0,
+             'payload'      : None,
+             'left'         : 0,
+             'close_code'   : None,
+             'close_reason' : None}
 
 
         blen = len(buf)
         blen = len(buf)
-        ret['left'] = blen
-        header_len = 2
+        f['left'] = blen
 
 
-        if blen < header_len:
-            return ret # Incomplete frame header
+        if blen < f['hlen']:
+            return f # Incomplete frame header
 
 
         b1, b2 = struct.unpack_from(">BB", buf)
         b1, b2 = struct.unpack_from(">BB", buf)
-        ret['opcode'] = b1 & 0x0f
-        ret['fin'] = (b1 & 0x80) >> 7
+        f['opcode'] = b1 & 0x0f
+        f['fin'] = (b1 & 0x80) >> 7
         has_mask = (b2 & 0x80) >> 7
         has_mask = (b2 & 0x80) >> 7
 
 
-        ret['length'] = b2 & 0x7f
+        f['length'] = b2 & 0x7f
 
 
-        if ret['length'] == 126:
-            header_len = 4
-            if blen < header_len:
-                return ret # Incomplete frame header
-            (ret['length'],) = struct.unpack_from('>xxH', buf)
-        elif ret['length'] == 127:
-            header_len = 10
-            if blen < header_len:
-                return ret # Incomplete frame header
-            (ret['length'],) = struct.unpack_from('>xxQ', buf)
+        if f['length'] == 126:
+            f['hlen'] = 4
+            if blen < f['hlen']:
+                return f # Incomplete frame header
+            (f['length'],) = struct.unpack_from('>xxH', buf)
+        elif f['length'] == 127:
+            f['hlen'] = 10
+            if blen < f['hlen']:
+                return f # Incomplete frame header
+            (f['length'],) = struct.unpack_from('>xxQ', buf)
 
 
-        full_len = header_len + has_mask * 4 + ret['length']
+        full_len = f['hlen'] + has_mask * 4 + f['length']
 
 
         if blen < full_len: # Incomplete frame
         if blen < full_len: # Incomplete frame
-            return ret # Incomplete frame header
+            return f # Incomplete frame header
 
 
         # Number of bytes that are part of the next frame(s)
         # Number of bytes that are part of the next frame(s)
-        ret['left'] = blen - full_len
+        f['left'] = blen - full_len
 
 
         # Process 1 frame
         # Process 1 frame
         if has_mask:
         if has_mask:
             # unmask payload
             # unmask payload
-            ret['mask'] = buf[header_len:header_len+4]
+            f['mask'] = buf[f['hlen']:f['hlen']+4]
             b = c = ''
             b = c = ''
-            if ret['length'] >= 4:
+            if f['length'] >= 4:
                 mask = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'),
                 mask = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'),
-                        offset=header_len, count=1)
+                        offset=f['hlen'], count=1)
                 data = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'),
                 data = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'),
-                        offset=header_len + 4, count=int(ret['length'] / 4))
+                        offset=f['hlen'] + 4, count=int(f['length'] / 4))
                 #b = numpy.bitwise_xor(data, mask).data
                 #b = numpy.bitwise_xor(data, mask).data
                 b = numpy.bitwise_xor(data, mask).tostring()
                 b = numpy.bitwise_xor(data, mask).tostring()
 
 
-            if ret['length'] % 4:
-                print "Partial unmask"
+            if f['length'] % 4:
+                print("Partial unmask")
                 mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
                 mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
-                        offset=header_len, count=(ret['length'] % 4))
+                        offset=f['hlen'], count=(f['length'] % 4))
                 data = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
                 data = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
-                        offset=full_len - (ret['length'] % 4),
-                        count=(ret['length'] % 4))
+                        offset=full_len - (f['length'] % 4),
+                        count=(f['length'] % 4))
                 c = numpy.bitwise_xor(data, mask).tostring()
                 c = numpy.bitwise_xor(data, mask).tostring()
-            ret['payload'] = b + c
+            f['payload'] = b + c
         else:
         else:
-            print "Unmasked frame:", repr(buf)
-            ret['payload'] = buf[(header_len + has_mask * 4):full_len]
+            print("Unmasked frame: %s" % repr(buf))
+            f['payload'] = buf[(f['hlen'] + has_mask * 4):full_len]
 
 
-        if base64 and ret['opcode'] in [1, 2]:
+        if base64 and f['opcode'] in [1, 2]:
             try:
             try:
-                ret['payload'] = b64decode(ret['payload'])
+                f['payload'] = b64decode(f['payload'])
             except:
             except:
-                print "Exception while b64decoding buffer:", repr(buf)
+                print("Exception while b64decoding buffer: %s" %
+                        repr(buf))
                 raise
                 raise
 
 
-        if ret['opcode'] == 0x08:
-            if ret['length'] >= 2:
-                ret['close_code'] = struct.unpack_from(
-                        ">H", ret['payload'])
-            if ret['length'] > 3:
-                ret['close_reason'] = ret['payload'][2:]
+        if f['opcode'] == 0x08:
+            if f['length'] >= 2:
+                f['close_code'] = struct.unpack_from(">H", f['payload'])
+            if f['length'] > 3:
+                f['close_reason'] = f['payload'][2:]
 
 
-        return ret
+        return f
 
 
     @staticmethod
     @staticmethod
     def encode_hixie(buf):
     def encode_hixie(buf):
-        return "\x00" + b64encode(buf) + "\xff"
+        return s2b("\x00" + b2s(b64encode(buf)) + "\xff"), 1, 1
 
 
     @staticmethod
     @staticmethod
     def decode_hixie(buf):
     def decode_hixie(buf):
-        end = buf.find('\xff')
+        end = buf.find(s2b('\xff'))
         return {'payload': b64decode(buf[1:end]),
         return {'payload': b64decode(buf[1:end]),
+                'hlen': 1,
+                'length': end - 1,
                 'left': len(buf) - (end + 1)}
                 'left': len(buf) - (end + 1)}
 
 
 
 
-    @staticmethod
-    def parse_handshake(handshake):
-        """ Parse fields from client WebSockets handshake. """
-        ret = {}
-        req_lines = handshake.split("\r\n")
-        if not req_lines[0].startswith("GET "):
-            raise Exception("Invalid handshake: no GET request line")
-        ret['path'] = req_lines[0].split(" ")[1]
-        for line in req_lines[1:]:
-            if line == "": break
-            try:
-                var, val = line.split(": ")
-            except:
-                raise Exception("Invalid handshake header: %s" % line)
-            ret[var] = val
-
-        if req_lines[-2] == "":
-            ret['key3'] = req_lines[-1]
-
-        return ret
-
     @staticmethod
     @staticmethod
     def gen_md5(keys):
     def gen_md5(keys):
         """ Generate hash value for WebSockets hixie-76. """
         """ Generate hash value for WebSockets hixie-76. """
@@ -309,7 +328,8 @@ Sec-WebSocket-Accept: %s\r
         num1 = int("".join([c for c in key1 if c.isdigit()])) / spaces1
         num1 = int("".join([c for c in key1 if c.isdigit()])) / spaces1
         num2 = int("".join([c for c in key2 if c.isdigit()])) / spaces2
         num2 = int("".join([c for c in key2 if c.isdigit()])) / spaces2
 
 
-        return md5(struct.pack('>II8s', num1, num2, key3)).digest()
+        return b2s(md5(struct.pack('>II8s',
+            int(num1), int(num2), key3)).digest())
 
 
     #
     #
     # WebSocketServer logging/output functions
     # WebSocketServer logging/output functions
@@ -324,7 +344,7 @@ Sec-WebSocket-Accept: %s\r
     def msg(self, msg):
     def msg(self, msg):
         """ Output message with handler_id prefix. """
         """ Output message with handler_id prefix. """
         if not self.daemon:
         if not self.daemon:
-            print "% 3d: %s" % (self.handler_id, msg)
+            print("% 3d: %s" % (self.handler_id, msg))
 
 
     def vmsg(self, msg):
     def vmsg(self, msg):
         """ Same as msg() but only if verbose. """
         """ Same as msg() but only if verbose. """
@@ -342,17 +362,27 @@ Sec-WebSocket-Accept: %s\r
         than 0, then the caller should call again when the socket is
         than 0, then the caller should call again when the socket is
         ready. """
         ready. """
 
 
+        tdelta = int(time.time()*1000) - self.start_time
+
         if bufs:
         if bufs:
             for buf in bufs:
             for buf in bufs:
                 if self.version.startswith("hybi"):
                 if self.version.startswith("hybi"):
                     if self.base64:
                     if self.base64:
-                        self.send_parts.append(self.encode_hybi(buf,
-                            opcode=1, base64=True))
+                        encbuf, lenhead, lentail = self.encode_hybi(
+                                buf, opcode=1, base64=True)
                     else:
                     else:
-                        self.send_parts.append(self.encode_hybi(buf,
-                            opcode=2, base64=False))
+                        encbuf, lenhead, lentail = self.encode_hybi(
+                                buf, opcode=2, base64=False)
+
                 else:
                 else:
-                    self.send_parts.append(self.encode_hixie(buf))
+                    encbuf, lenhead, lentail = self.encode_hixie(buf)
+
+                if self.rec:
+                    self.rec.write("%s,\n" %
+                            repr("{%s{" % tdelta
+                                + encbuf[lenhead:-lentail]))
+
+                self.send_parts.append(encbuf)
 
 
         while self.send_parts:
         while self.send_parts:
             # Send pending frames
             # Send pending frames
@@ -377,6 +407,7 @@ Sec-WebSocket-Accept: %s\r
 
 
         closed = False
         closed = False
         bufs = []
         bufs = []
+        tdelta = int(time.time()*1000) - self.start_time
 
 
         buf = self.client.recv(self.buffer_size)
         buf = self.client.recv(self.buffer_size)
         if len(buf) == 0:
         if len(buf) == 0:
@@ -392,7 +423,7 @@ Sec-WebSocket-Accept: %s\r
             if self.version.startswith("hybi"):
             if self.version.startswith("hybi"):
 
 
                 frame = self.decode_hybi(buf, base64=self.base64)
                 frame = self.decode_hybi(buf, base64=self.base64)
-                #print "Received buf: %s, frame: %s" % (repr(buf), frame)
+                #print("Received buf: %s, frame: %s" % (repr(buf), frame))
 
 
                 if frame['payload'] == None:
                 if frame['payload'] == None:
                     # Incomplete/partial frame
                     # Incomplete/partial frame
@@ -416,7 +447,7 @@ Sec-WebSocket-Accept: %s\r
                     buf = buf[2:]
                     buf = buf[2:]
                     continue # No-op
                     continue # No-op
 
 
-                elif buf.count('\xff') == 0:
+                elif buf.count(s2b('\xff')) == 0:
                     # Partial frame
                     # Partial frame
                     self.traffic("}.")
                     self.traffic("}.")
                     self.recv_part = buf
                     self.recv_part = buf
@@ -426,6 +457,13 @@ Sec-WebSocket-Accept: %s\r
 
 
             self.traffic("}")
             self.traffic("}")
 
 
+            if self.rec:
+                start = frame['hlen']
+                end = frame['hlen'] + frame['length']
+                self.rec.write("%s,\n" %
+                        repr("}%s}" % tdelta + buf[start:end]))
+
+
             bufs.append(frame['payload'])
             bufs.append(frame['payload'])
 
 
             if frame['left']:
             if frame['left']:
@@ -439,7 +477,7 @@ Sec-WebSocket-Accept: %s\r
         """ Send a WebSocket orderly close frame. """
         """ Send a WebSocket orderly close frame. """
 
 
         if self.version.startswith("hybi"):
         if self.version.startswith("hybi"):
-            msg = ''
+            msg = s2b('')
             if code != None:
             if code != None:
                 msg = struct.pack(">H%ds" % (len(reason)), code)
                 msg = struct.pack(">H%ds" % (len(reason)), code)
 
 
@@ -447,7 +485,7 @@ Sec-WebSocket-Accept: %s\r
             self.client.send(buf)
             self.client.send(buf)
 
 
         elif self.version == "hixie-76":
         elif self.version == "hixie-76":
-            buf = self.encode_hixie('\xff\x00')
+            buf = s2b('\xff\x00')
             self.client.send(buf)
             self.client.send(buf)
 
 
         # No orderly close for 75
         # No orderly close for 75
@@ -483,14 +521,16 @@ Sec-WebSocket-Accept: %s\r
         if handshake == "":
         if handshake == "":
             raise self.EClose("ignoring empty handshake")
             raise self.EClose("ignoring empty handshake")
 
 
-        elif handshake.startswith("<policy-file-request/>"):
+        elif handshake.startswith(s2b("<policy-file-request/>")):
             # Answer Flash policy request
             # Answer Flash policy request
             handshake = sock.recv(1024)
             handshake = sock.recv(1024)
-            sock.send(self.policy_response)
+            sock.send(s2b(self.policy_response))
             raise self.EClose("Sending flash policy response")
             raise self.EClose("Sending flash policy response")
 
 
         elif handshake[0] in ("\x16", "\x80"):
         elif handshake[0] in ("\x16", "\x80"):
             # SSL wrap the connection
             # SSL wrap the connection
+            if not ssl:
+                raise self.EClose("SSL connection but no 'ssl' module")
             if not os.path.exists(self.cert):
             if not os.path.exists(self.cert):
                 raise self.EClose("SSL connection but '%s' not found"
                 raise self.EClose("SSL connection but '%s' not found"
                                   % self.cert)
                                   % self.cert)
@@ -500,7 +540,8 @@ Sec-WebSocket-Accept: %s\r
                         server_side=True,
                         server_side=True,
                         certfile=self.cert,
                         certfile=self.cert,
                         keyfile=self.key)
                         keyfile=self.key)
-            except ssl.SSLError, x:
+            except ssl.SSLError:
+                _, x, _ = sys.exc_info()
                 if x.args[0] == ssl.SSL_ERROR_EOF:
                 if x.args[0] == ssl.SSL_ERROR_EOF:
                     raise self.EClose("")
                     raise self.EClose("")
                 else:
                 else:
@@ -517,29 +558,21 @@ Sec-WebSocket-Accept: %s\r
             scheme = "ws"
             scheme = "ws"
             stype = "Plain non-SSL (ws://)"
             stype = "Plain non-SSL (ws://)"
 
 
-        # Now get the data from the socket
-        handshake = retsock.recv(4096)
-
-        if len(handshake) == 0:
-            raise self.EClose("Client closed during handshake")
-
-        # Check for and handle normal web requests
-        if (handshake.startswith('GET ') and
-                handshake.find('Upgrade: WebSocket\r\n') == -1 and
-                handshake.find('Upgrade: websocket\r\n') == -1):
-            if not self.web:
-                raise self.EClose("Normal web request received but disallowed")
-            sh = SplitHTTPHandler(handshake, retsock, address)
-            if sh.last_code < 200 or sh.last_code >= 300:
-                raise self.EClose(sh.last_message)
-            elif self.verbose:
-                raise self.EClose(sh.last_message)
-            else:
-                raise self.EClose("")
+        wsh = WSRequestHandler(retsock, address, not self.web)
+        if wsh.last_code == 101:
+            # Continue on to handle WebSocket upgrade
+            pass
+        elif wsh.last_code == 405:
+            raise self.EClose("Normal web request received but disallowed")
+        elif wsh.last_code < 200 or wsh.last_code >= 300:
+            raise self.EClose(wsh.last_message)
+        elif self.verbose:
+            raise self.EClose(wsh.last_message)
+        else:
+            raise self.EClose("")
 
 
-        #self.msg("handshake: " + repr(handshake))
-        # Parse client WebSockets handshake
-        h = self.headers = self.parse_handshake(handshake)
+        h = self.headers = wsh.headers
+        path = self.path = wsh.path
 
 
         prot = 'WebSocket-Protocol'
         prot = 'WebSocket-Protocol'
         protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',')
         protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',')
@@ -548,8 +581,8 @@ Sec-WebSocket-Accept: %s\r
         if ver:
         if ver:
             # HyBi/IETF version of the protocol
             # HyBi/IETF version of the protocol
 
 
-            if not numpy or not ctypes:
-                self.EClose("Python numpy and ctypes modules required for HyBi-07 or greater")
+            if sys.hexversion < 0x2060000 or not numpy:
+                raise self.EClose("Python >= 2.6 and numpy module is required for HyBi-07 or greater")
 
 
             if ver == '7':
             if ver == '7':
                 self.version = "hybi-07"
                 self.version = "hybi-07"
@@ -567,7 +600,7 @@ Sec-WebSocket-Accept: %s\r
                 raise self.EClose("Client must support 'binary' or 'base64' protocol")
                 raise self.EClose("Client must support 'binary' or 'base64' protocol")
 
 
             # Generate the hash value for the accept header
             # Generate the hash value for the accept header
-            accept = b64encode(sha1(key + self.GUID).digest())
+            accept = b64encode(sha1(s2b(key + self.GUID)).digest())
 
 
             response = self.server_handshake_hybi % accept
             response = self.server_handshake_hybi % accept
             if self.base64:
             if self.base64:
@@ -592,7 +625,7 @@ Sec-WebSocket-Accept: %s\r
             self.base64 = True
             self.base64 = True
 
 
             response = self.server_handshake_hixie % (pre,
             response = self.server_handshake_hixie % (pre,
-                    h['Origin'], pre, scheme, h['Host'], h['path'])
+                    h['Origin'], pre, scheme, h['Host'], path)
 
 
             if 'base64' in protocols:
             if 'base64' in protocols:
                 response += "%sWebSocket-Protocol: base64\r\n" % pre
                 response += "%sWebSocket-Protocol: base64\r\n" % pre
@@ -606,7 +639,7 @@ Sec-WebSocket-Accept: %s\r
 
 
         # Send server WebSockets handshake response
         # Send server WebSockets handshake response
         #self.msg("sending response [%s]" % response)
         #self.msg("sending response [%s]" % response)
-        retsock.send(response)
+        retsock.send(s2b(response))
 
 
         # Return the WebSockets socket which may be SSL wrapped
         # Return the WebSockets socket which may be SSL wrapped
         return retsock
         return retsock
@@ -624,9 +657,8 @@ Sec-WebSocket-Accept: %s\r
         #self.vmsg("Running poll()")
         #self.vmsg("Running poll()")
         pass
         pass
 
 
-    def top_SIGCHLD(self, sig, stack):
-        # Reap zombies after calling child SIGCHLD handler
-        self.do_SIGCHLD(sig, stack)
+    def fallback_SIGCHLD(self, sig, stack):
+        # Reap zombies when using os.fork() (python 2.4)
         self.vmsg("Got SIGCHLD, reaping zombies")
         self.vmsg("Got SIGCHLD, reaping zombies")
         try:
         try:
             result = os.waitpid(-1, os.WNOHANG)
             result = os.waitpid(-1, os.WNOHANG)
@@ -636,14 +668,52 @@ Sec-WebSocket-Accept: %s\r
         except (OSError):
         except (OSError):
             pass
             pass
 
 
-    def do_SIGCHLD(self, sig, stack):
-        pass
-
     def do_SIGINT(self, sig, stack):
     def do_SIGINT(self, sig, stack):
         self.msg("Got SIGINT, exiting")
         self.msg("Got SIGINT, exiting")
         sys.exit(0)
         sys.exit(0)
 
 
-    def new_client(self, client):
+    def top_new_client(self, startsock, address):
+        """ Do something with a WebSockets client connection. """
+        # Initialize per client settings
+        self.send_parts = []
+        self.recv_part  = None
+        self.base64     = False
+        self.rec        = None
+        self.start_time = int(time.time()*1000)
+
+        # handler process
+        try:
+            try:
+                self.client = self.do_handshake(startsock, address)
+
+                if self.record:
+                    # Record raw frame data as JavaScript array
+                    fname = "%s.%s" % (self.record,
+                                        self.handler_id)
+                    self.msg("opening record file: %s" % fname)
+                    self.rec = open(fname, 'w+')
+                    self.rec.write("var VNC_frame_data = [\n")
+
+                self.new_client()
+            except self.EClose:
+                _, exc, _ = sys.exc_info()
+                # Connection was not a WebSockets connection
+                if exc.args[0]:
+                    self.msg("%s: %s" % (address[0], exc.args[0]))
+            except Exception:
+                _, exc, _ = sys.exc_info()
+                self.msg("handler exception: %s" % str(exc))
+                if self.verbose:
+                    self.msg(traceback.format_exc())
+        finally:
+            if self.rec:
+                self.rec.write("'EOF']\n")
+                self.rec.close()
+
+            if self.client and self.client != startsock:
+                self.client.close()
+
+    def new_client(self):
         """ Do something with a WebSockets client connection. """
         """ Do something with a WebSockets client connection. """
         raise("WebSocketServer.new_client() must be overloaded")
         raise("WebSocketServer.new_client() must be overloaded")
 
 
@@ -665,9 +735,11 @@ Sec-WebSocket-Accept: %s\r
 
 
         self.started()  # Some things need to happen after daemonizing
         self.started()  # Some things need to happen after daemonizing
 
 
-        # Reep zombies
-        signal.signal(signal.SIGCHLD, self.top_SIGCHLD)
+        # Allow override of SIGINT
         signal.signal(signal.SIGINT, self.do_SIGINT)
         signal.signal(signal.SIGINT, self.do_SIGINT)
+        if not Process:
+            # os.fork() (python 2.4) child reaper
+            signal.signal(signal.SIGCHLD, self.fallback_SIGCHLD)
 
 
         while True:
         while True:
             try:
             try:
@@ -679,14 +751,17 @@ Sec-WebSocket-Accept: %s\r
                     try:
                     try:
                         self.poll()
                         self.poll()
 
 
-                        ready = select.select([lsock], [], [], 1)[0];
+                        ready = select.select([lsock], [], [], 1)[0]
                         if lsock in ready:
                         if lsock in ready:
                             startsock, address = lsock.accept()
                             startsock, address = lsock.accept()
                         else:
                         else:
                             continue
                             continue
-                    except Exception, exc:
+                    except Exception:
+                        _, exc, _ = sys.exc_info()
                         if hasattr(exc, 'errno'):
                         if hasattr(exc, 'errno'):
                             err = exc.errno
                             err = exc.errno
+                        elif hasattr(exc, 'args'):
+                            err = exc.args[0]
                         else:
                         else:
                             err = exc[0]
                             err = exc[0]
                         if err == errno.EINTR:
                         if err == errno.EINTR:
@@ -695,55 +770,67 @@ Sec-WebSocket-Accept: %s\r
                         else:
                         else:
                             raise
                             raise
 
 
-                    self.vmsg('%s: forking handler' % address[0])
-                    pid = os.fork()
-
-                    if pid == 0:
-                        # Initialize per client settings
-                        self.send_parts = []
-                        self.recv_part  = None
-                        self.base64     = False
-                        # handler process
-                        self.client = self.do_handshake(
-                                startsock, address)
-                        self.new_client()
+                    if Process:
+                        self.vmsg('%s: new handler Process' % address[0])
+                        p = Process(target=self.top_new_client,
+                                args=(startsock, address))
+                        p.start()
+                        # child will not return
                     else:
                     else:
-                        # parent process
-                        self.handler_id += 1
-
-                except self.EClose, exc:
-                    # Connection was not a WebSockets connection
-                    if exc.args[0]:
-                        self.msg("%s: %s" % (address[0], exc.args[0]))
-                except KeyboardInterrupt, exc:
+                        # python 2.4
+                        self.vmsg('%s: forking handler' % address[0])
+                        pid = os.fork()
+                        if pid == 0:
+                            # child handler process
+                            self.top_new_client(startsock, address)
+                            break  # child process exits
+
+                    # parent process
+                    self.handler_id += 1
+
+                except KeyboardInterrupt:
+                    _, exc, _ = sys.exc_info()
+                    print("In KeyboardInterrupt")
                     pass
                     pass
-                except Exception, exc:
+                except SystemExit:
+                    _, exc, _ = sys.exc_info()
+                    print("In SystemExit")
+                    break
+                except Exception:
+                    _, exc, _ = sys.exc_info()
                     self.msg("handler exception: %s" % str(exc))
                     self.msg("handler exception: %s" % str(exc))
                     if self.verbose:
                     if self.verbose:
                         self.msg(traceback.format_exc())
                         self.msg(traceback.format_exc())
 
 
             finally:
             finally:
-                if self.client and self.client != startsock:
-                    self.client.close()
                 if startsock:
                 if startsock:
                     startsock.close()
                     startsock.close()
 
 
-            if pid == 0:
-                break # Child process exits
-
 
 
-# HTTP handler with request from a string and response to a socket
-class SplitHTTPHandler(SimpleHTTPRequestHandler):
-    def __init__(self, req, resp, addr):
-        # Save the response socket
-        self.response = resp
+# HTTP handler with WebSocket upgrade support
+class WSRequestHandler(SimpleHTTPRequestHandler):
+    def __init__(self, req, addr, only_upgrade=False):
+        self.only_upgrade = only_upgrade # only allow upgrades
         SimpleHTTPRequestHandler.__init__(self, req, addr, object())
         SimpleHTTPRequestHandler.__init__(self, req, addr, object())
 
 
-    def setup(self):
-        self.connection = self.response
-        # Duck type request string to file object
-        self.rfile = StringIO(self.request)
-        self.wfile = self.connection.makefile('wb', self.wbufsize)
+    def do_GET(self):
+        if (self.headers.get('upgrade') and
+                self.headers.get('upgrade').lower() == 'websocket'):
+
+            if (self.headers.get('sec-websocket-key1') or
+                    self.headers.get('websocket-key1')):
+                # For Hixie-76 read out the key hash
+                self.headers.__setitem__('key3', self.rfile.read(8))
+
+            # Just indicate that an WebSocket upgrade is needed
+            self.last_code = 101
+            self.last_message = "101 Switching Protocols"
+        elif self.only_upgrade:
+            # Normal web request responses are disabled
+            self.last_code = 405
+            self.last_message = "405 Method Not Allowed"
+        else:
+            SimpleHTTPRequestHandler.do_GET(self)
 
 
     def send_response(self, code, message=None):
     def send_response(self, code, message=None):
         # Save the status code
         # Save the status code
@@ -754,4 +841,3 @@ class SplitHTTPHandler(SimpleHTTPRequestHandler):
         # Save instead of printing
         # Save instead of printing
         self.last_message = f % args
         self.last_message = f % args
 
 
-

+ 10 - 31
utils/websockify

@@ -2,7 +2,7 @@
 
 
 '''
 '''
 A WebSocket to TCP socket proxy with support for "wss://" encryption.
 A WebSocket to TCP socket proxy with support for "wss://" encryption.
-Copyright 2010 Joel Martin
+Copyright 2011 Joel Martin
 Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
 Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
 
 
 You can make a cert/key with openssl using:
 You can make a cert/key with openssl using:
@@ -74,7 +74,7 @@ Traffic Legend:
         WebSocketServer.__init__(self, *args, **kwargs)
         WebSocketServer.__init__(self, *args, **kwargs)
 
 
     def run_wrap_cmd(self):
     def run_wrap_cmd(self):
-        print "Starting '%s'" % " ".join(self.wrap_cmd)
+        print("Starting '%s'" % " ".join(self.wrap_cmd))
         self.wrap_times.append(time.time())
         self.wrap_times.append(time.time())
         self.wrap_times.pop(0)
         self.wrap_times.pop(0)
         self.cmd = subprocess.Popen(
         self.cmd = subprocess.Popen(
@@ -88,14 +88,14 @@ Traffic Legend:
         # Need to call wrapped command after daemonization so we can
         # Need to call wrapped command after daemonization so we can
         # know when the wrapped command exits
         # know when the wrapped command exits
         if self.wrap_cmd:
         if self.wrap_cmd:
-            print "  - proxying from %s:%s to '%s' (port %s)\n" % (
+            print("  - proxying from %s:%s to '%s' (port %s)\n" % (
                     self.listen_host, self.listen_port,
                     self.listen_host, self.listen_port,
-                    " ".join(self.wrap_cmd), self.target_port)
+                    " ".join(self.wrap_cmd), self.target_port))
             self.run_wrap_cmd()
             self.run_wrap_cmd()
         else:
         else:
-            print "  - proxying from %s:%s to %s:%s\n" % (
+            print("  - proxying from %s:%s to %s:%s\n" % (
                     self.listen_host, self.listen_port,
                     self.listen_host, self.listen_port,
-                    self.target_host, self.target_port)
+                    self.target_host, self.target_port))
 
 
     def poll(self):
     def poll(self):
         # If we are wrapping a command, check it's status
         # If we are wrapping a command, check it's status
@@ -118,7 +118,7 @@ Traffic Legend:
                 if (now - avg) < 10:
                 if (now - avg) < 10:
                     # 3 times in the last 10 seconds
                     # 3 times in the last 10 seconds
                     if self.spawn_message:
                     if self.spawn_message:
-                        print "Command respawning too fast"
+                        print("Command respawning too fast")
                         self.spawn_message = False
                         self.spawn_message = False
                 else:
                 else:
                     self.run_wrap_cmd()
                     self.run_wrap_cmd()
@@ -138,15 +138,6 @@ Traffic Legend:
         Called after a new WebSocket connection has been established.
         Called after a new WebSocket connection has been established.
         """
         """
 
 
-        self.rec = None
-        if self.record:
-            # Record raw frame data as a JavaScript compatible file
-            fname = "%s.%s" % (self.record,
-                                self.handler_id)
-            self.msg("opening record file: %s" % fname)
-            self.rec = open(fname, 'w+')
-            self.rec.write("var VNC_frame_data = [\n")
-
         # Connect to the target
         # Connect to the target
         self.msg("connecting to: %s:%s" % (
         self.msg("connecting to: %s:%s" % (
                  self.target_host, self.target_port))
                  self.target_host, self.target_port))
@@ -154,19 +145,17 @@ Traffic Legend:
         tsock.connect((self.target_host, self.target_port))
         tsock.connect((self.target_host, self.target_port))
 
 
         if self.verbose and not self.daemon:
         if self.verbose and not self.daemon:
-            print self.traffic_legend
+            print(self.traffic_legend)
 
 
         # Start proxying
         # Start proxying
         try:
         try:
             self.do_proxy(tsock)
             self.do_proxy(tsock)
         except:
         except:
             if tsock:
             if tsock:
+                tsock.shutdown(socket.SHUT_RDWR)
                 tsock.close()
                 tsock.close()
                 self.vmsg("%s:%s: Target closed" %(
                 self.vmsg("%s:%s: Target closed" %(
                     self.target_host, self.target_port))
                     self.target_host, self.target_port))
-            if self.rec:
-                self.rec.write("'EOF']\n")
-                self.rec.close()
             raise
             raise
 
 
     def do_proxy(self, target):
     def do_proxy(self, target):
@@ -177,11 +166,9 @@ Traffic Legend:
         c_pend = 0
         c_pend = 0
         tqueue = []
         tqueue = []
         rlist = [self.client, target]
         rlist = [self.client, target]
-        tstart = int(time.time()*1000)
 
 
         while True:
         while True:
             wlist = []
             wlist = []
-            tdelta = int(time.time()*1000) - tstart
 
 
             if tqueue: wlist.append(target)
             if tqueue: wlist.append(target)
             if cqueue or c_pend: wlist.append(self.client)
             if cqueue or c_pend: wlist.append(self.client)
@@ -212,11 +199,8 @@ Traffic Legend:
             if self.client in outs:
             if self.client in outs:
                 # Send queued target data to the client
                 # Send queued target data to the client
                 c_pend = self.send_frames(cqueue)
                 c_pend = self.send_frames(cqueue)
-                cqueue = []
 
 
-                #if self.rec:
-                #    self.rec.write("%s,\n" %
-                #            repr("{%s{" % tdelta + dat[1:-1]))
+                cqueue = []
 
 
 
 
             if self.client in ins:
             if self.client in ins:
@@ -224,11 +208,6 @@ Traffic Legend:
                 bufs, closed = self.recv_frames()
                 bufs, closed = self.recv_frames()
                 tqueue.extend(bufs)
                 tqueue.extend(bufs)
 
 
-                #if self.rec:
-                #    for b in bufs:
-                #        self.rec.write(
-                #                repr("}%s}%s" % (tdelta, b)) + ",\n")
-
                 if closed:
                 if closed:
                     # TODO: What about blocking on client socket?
                     # TODO: What about blocking on client socket?
                     self.send_close()
                     self.send_close()