8
\$\begingroup\$

It's an extremely simply HTTP server in Python using the socket library, and a few others to get the MIME type, etc...

I've also avoided the ../../ vulnerability, although some of the code in the send_file function seems a bit weak.

It should also be PEP8 compliant, aside from maybe some trailing whitespace in comments.

import filetype
import socket
import _thread
class ServerSocket():
 ''' Recieves connections, parses the request and ships it off to a handler
 '''
 def __init__(self, address, handler=None, *args, **kwargs):
 ''' Creates a server socket and defines a handler for the server
 '''
 self.socket = socket.socket()
 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 self.socket.bind(address)
 self.handler_args = args
 self.handler_kwargs = kwargs
 if handler:
 self.handler = handler # The custom handler
 else:
 self.handler = Handler # The default handler
 def initialise(self, open_connections=5):
 ''' Initilises the server socket and has it listen for connections
 '''
 self.socket.listen(open_connections)
 self.listen()
 def parse(self, data):
 ''' Splits a packet into
 the request,
 the headers (which includes the request),
 and contents
 '''
 stringed = str(data, 'utf-8')
 split = stringed.split('\r\n\r\n')
 headers = split[0]
 if len(split) > 1:
 content = split[1]
 else:
 content = []
 request = headers.split(' ')[0]
 return request, headers, content
 def handle(self, client, address):
 ''' Parses the data and handles the request. It then closes the connection
 '''
 try:
 data = client.recv(1024)
 except ConnectionResetError:
 if self.handler_kwargs["logging"] is True:
 print(f'{address[0]} unexpectedly quit')
 client.close()
 return
 parsed = self.parse(data)
 handler = self.handler(self.handler_args, self.handler_kwargs)
 handler.handle(client, parsed, address)
 client.close()
 def listen(self):
 ''' Listens until a keyboard interrupt and handles each connection in a
 new thread
 '''
 try:
 while True:
 client_data = self.socket.accept()
 if self.handler_kwargs['logging'] is True:
 print(f'Connection from {client_data[1][0]}')
 _thread.start_new_thread(self.handle, client_data)
 except KeyboardInterrupt:
 self.socket.close()
class Handler():
 ''' Handles requests from the Server Socket
 '''
 def __init__(self, args, kwargs):
 self.args = args
 self.kwargs = kwargs
 def set_status(self, code, message):
 ''' Used to add a status line:
 - 'HTTP/1.1 200 OK' or 'HTTP/1.1 404 Not Found', etc...
 '''
 self.reply_headers = [f'HTTP/1.0 {code} {message}']
 def set_header(self, header, content):
 ''' Defines a custom header and adds it to the response
 '''
 self.reply_headers += [f'{header}: {content}']
 def response(self, content):
 ''' Adds to the content of the response
 '''
 if type(content) == str:
 self.reply_content += content.split('\n')
 else:
 self.reply_content += [content]
 def calculate_content_length(self):
 ''' Calculates the content length and adds it to the header
 '''
 length = len(self.reply_content) * 2
 lengths = [len(line) for line in self.reply_content]
 length += sum(lengths)
 self.set_header('Content-Length', length)
 def get_type(self, file_name):
 return filetype.guess('./public/'+file_name)
 def extract_file_name(self, file_name=None):
 if file_name:
 f_name = file_name[1:]
 else:
 f_name = self.request_status.split(' ')[1][1:]
 return f_name
 def send_file(self, file_name=None):
 if file_name is None:
 file_name = self.extract_file_name()
 if file_name == '':
 file_name = 'index.html'
 elif file_name[0] in './':
 self.set_status(403, "Forbidden")
 self.set_header('Content-Type', 'text/html')
 self.reply_content = ['<p>Error 403: Forbidden</p>']
 return
 try:
 with open('./public/'+file_name, 'rb') as file:
 file_contents = file.read()
 except FileNotFoundError:
 self.set_status(404, 'Not Found')
 self.set_header('Content-Type', 'text/html')
 self.reply_content = ['<p>Error 404: File not found</p>']
 return
 file_type = self.get_type(file_name)
 if file_type is not None:
 self.set_header('Content-Type', file_type.MIME)
 elif file_name.split('.')[-1] == 'html':
 self.set_header('Content-Type', 'text/html')
 else:
 self.set_header('Content-Type', 'text/txt')
 self.response(file_contents)
 def get_request_address(self):
 return self.address
 def parse_headers(self, headers):
 t = {}
 for header in headers[1:]:
 t[header.split(': ')[0]] = header.split(': ')[1]
 return t
 def reply(self):
 ''' Assembles the response and sends it to the client
 '''
 if self.reply_headers[0][0:4] != "HTTP":
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.reply_content = ['<p>Response Status unspecified</p>']
 self.calculate_content_length()
 message = '\r\n'.join(self.reply_headers)
 message += '\r\n\r\n'
 try:
 message += '\r\n'.join(self.reply_content)
 message += '\r\n'
 except TypeError:
 message = bytes(message, 'utf-8')
 message += b'\r\n'.join(self.reply_content)
 message += b'\r\n'
 try:
 if type(message) == str:
 self.client.send(bytes(message, 'utf-8'))
 else:
 self.client.send(message)
 except:
 pass
 def handle(self, client, parsed_data, address):
 ''' Initialises variables and case-switches the request type to
 determine the handler function
 '''
 self.client = client
 self.address = address
 self.reply_headers = []
 self.reply_content = []
 self.headers = True
 self.request_status = parsed_data[1].split('\r\n')[0]
 request = parsed_data[0]
 headers = self.parse_headers(parsed_data[1].split('\r\n'))
 contents = parsed_data[2]
 if request == "GET":
 func = self.get
 elif request == "POST":
 func = self.post
 elif request == "HEAD":
 func = self.head
 elif request == "PUT":
 func = self.put
 elif request == "DELETE":
 func = self.delete
 elif request == "CONNECT":
 func = self.connect
 elif request == "OPTIONS":
 func = self.options
 elif request == "TRACE":
 func = self.trace
 elif request == "PATCH":
 func = self.patch
 else:
 func = self.default
 func(headers, contents)
 self.reply()
 def default(self, headers, contents):
 ''' If the request is not known, defaults to this
 '''
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.response('''<p>Unknown Request Type</p>''')
 def get(self, headers, contents):
 ''' Overwrite to customly handle GET requests
 '''
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.response('''<p>Successfully got a GET Request</p>''')
 def post(self, headers, contents):
 ''' Overwrite to customly handle POST requests
 '''
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.response('''<p>Successfully got a POST Request</p>''')
 def head(self, headers, contents):
 ''' Overwrite to customly handle HEAD requests
 '''
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.response('''<p>Successfully got a HEAD Request</p>''')
 def put(self, headers, contents):
 ''' Overwrite to customly handle PUT requests
 '''
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.response('''<p>Successfully got a PUT Request</p>''')
 def delete(self, headers, contents):
 ''' Overwrite to customly handle DELETE requests
 '''
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.response('''<p>Successfully got a DELETE Request</p>''')
 def connect(self, headers, contents):
 ''' Overwrite to customly handle CONNECT requests
 '''
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.response('''<p>Successfully got a CONNECT Request</p>''')
 def options(self, headers, contents):
 ''' Overwrite to customly handle OPTIONS requests
 '''
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.response('''<p>Successfully got an OPTIONS Request</p>''')
 def trace(self, headers, contents):
 ''' Overwrite to customly handle TRACE requests
 '''
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.response('''<p>Successfully got a TRACE Request</p>''')
 def patch(self, headers, contents):
 ''' Overwrite to customly handle PATCH requests
 '''
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.response('''<p>Successfully got a PATCH Request</p>''')
# =======================================================================
if __name__ == "__main__":
 import sys
 class CustomHandler(Handler):
 def get(self, headers, contents):
 self.set_status(200, 'OK')
 request_address = self.get_request_address()[0]
 file_name = self.extract_file_name()
 print(f'{request_address} -> {headers["Host"]}/{file_name}')
 self.send_file()
 def run():
 if len(sys.argv) == 2:
 port = int(sys.argv[1])
 else:
 port = 80
 try:
 print('Initialising...', end='')
 http_server = ServerSocket(
 ('0.0.0.0', port),
 CustomHandler,
 logging=True
 )
 print('Done')
 http_server.initialise()
 except Exception as e:
 print(f'{e}')
 run()
200_success
145k22 gold badges190 silver badges478 bronze badges
asked Apr 13, 2019 at 0:03
\$\endgroup\$

1 Answer 1

9
\$\begingroup\$

ServerSocket

self.handler

The handler evaluation in __init__ can be accomplished with an or ternary operation. It's clearer to the reader as to what's going on. Also, the name could be changed to HandlerClass, since it represents a class rather than an instance:

 self.HandlerClass = handler or Handler

self.handler logging

The check against handler_kwargs makes the code a bit difficult to follow, since the SocketServer is now in charge of something that arguably the handler should be doing. If the server is who should be doing the logging, then I would leave the logging check out of the handler_kwargs entirely. Store self.logging as a boolean setting on the server instance, and just check against that:

# for traceback handling
import traceback 
class SocketServer:
 def __init__(self, address, handler=None, *args, **kwargs):
 ~skipping some code~
 # this is so you don't unpack logging from kwargs since
 # it looks like you need it in your handler
 self.logging = kwargs.get('logging', False)
 ...
 def handle(self):
 try:
 data = client.recv(1024)
 except ConnectionResetError as e:
 # use the more pythonic `if bool` check
 # here, rather than comparing against a singleton
 if self.logging:
 print(f'{address[0]} unexpectedly quit: {e}', file=sys.stderr)
 traceback.print_exception(*sys.exc_info(), file=sys.stderr) 

I've added a sys.stderr stream to your print statement, and added the exception to your print statement. I've also added a traceback print which points to stderr.

handle Refactor

The client.close(); return statement can also be refactored using try/except's else feature and a finally block, since you always want the client to close

 try:
 data = client.recv(1024)
 except ConnectionResetError as e:
 if self.logging:
 print(f'{address[0]} unexpectedly quit: {e}', file=sys.stderr)
 traceback.print_exception(*sys.exc_info(), file=sys.stderr) 
 # this is if no exception fired
 else:
 parsed = self.parse(data)
 handler = self.HandlerClass()
 handler.handle(client, parsed, address)
 # this will always execute
 finally:
 client.close()

Convert bytes to str

parse

First, when converting bytes to str, use bytes.decode() rather than str(byte_obj, encoding). The headers and content evaluation can be handled using argument unpacking, and from there you can use an or expression with any on contents to either take the first result or create an empty list:

 def parse(self, data):
 ''' Splits a packet into
 the request,
 the headers (which includes the request),
 and contents
 '''
 stringed = data.decode().split('\r\n\r\n')
 # The headers and content blocks can be handled also with argument
 # unpacking and a ternary operator on content:
 headers, *content = data.decode().split('\r\n\r\n')
 content = content[0] if content else []
 request = headers.split(' ')[0]
 return request, headers, content

Handler

Adding to lists

There are lots of cases where you do this:

self.some_list += [some_value]

Just use .append(some_value), it's faster since you don't have to create a list just to add the value to an existing list. For example:

 def set_header(self, header, content):
 ''' Defines a custom header and adds it to the response
 '''
 # this is much quicker, do this instead
 self.reply_headers.append(f'{header}: {content}')

type-checking

Use isinstance rather than type(object) == some_type:

# This
 def response(self, content):
 ''' Adds to the content of the response
 '''
 if type(content) == str:
 self.reply_content += content.split('\n')
 else:
 self.reply_content += [content]
# should be this
 def response(self, content):
 if isinstance(content, str):
 self.reply_content.extend(content.split('\n'))
 else:
 self.reply_content.append(content)

Note that I'm also switching the addition of lists to appropriate calls to append and extend. Though, looking through the rest of your code, you only use this against str or bytes types, so I'd refactor to the following:

 def response(self, content):
 # force everything to bytes-type for convenience, that way
 # you never have to worry about TypeErrors later
 content = content if isinstance(content, bytes) else content.encode()
 self.reply_content.append(content)

content-length calculation

Here, you can drop the creation of the lengths list, and just use sum on map:

 def calculate_content_length(self):
 ''' Calculates the content length and adds it to the header
 '''
 length = len(self.reply_content) * 2
 # sum will take any iterable, even generators and maps
 # len in map is the function to be applied to each element
 length += sum(map(len, self.reply_content))
 self.set_header('Content-Length', length)

It's faster and doesn't build as many objects inside the function

Magic Numbers

In your extract_file_name function, you are slicing the file name from the first element, though it's not completely clear why:

 def extract_file_name(self, file_name=None):
 if file_name:
 f_name = file_name[1:]
 else:
 f_name = self.request_status.split(' ')[1][1:]
 return f_name

This is usually a code smell and you should include a docstring and/or comments to explain why you slice that way. Otherwise, the index is a "magic number" and can be difficult for you or others to maintain later. You also never use the file_name argument at any point when calling this function, so I might just leave it out.

Sending Files

Most of the improvements here follow the ones that have been suggested above:

Ternary Operation or Use or for file_name

file_name = file_name or self.extract_file_name()
## or 
file_name = file_name if file_name else self.extract_file_name()

The latter mirrors your if statement a bit more

Checking an empty string

Use if some_string not if some_string == ''.

Checking string start and end values

The str.startswith and str.endswith methods will help here, and they avoid indexing or slicing which can impair readability:

if file_name.startswith('./'):
 # do something
if file_name.endswith('html'):
 # do something

The logic in send_file could be cleaned up a bit. First, I think a handle_error method would do nicely to clean up some of the repeated code where you handle exceptions:

 def handle_error(self, code, short_reason, reason):
 """
 code: integer error code
 short_reason: string denoting the short error reason
 reason: string denoting the full error reason
 example:
 self.handle_error(404, 'Not Found', 'File Not Found')
 """
 self.set_status(code, short_reason)
 self.set_header('Content-Type', 'text/html')
 self.reply_content = [f'<p>Error {code}: {reason} </p>'] 

Next, I think the filename checking can be refactored to make a bit more sense. First, it seems a bit counterintuitive that a path that might start with '.' is forbidden. What about ./public/index.html? If you're trying to avoid folder traversal such as paths with . and .. in them, it might not be that bad to use a regex. For example, what if I tried to give you a path like `../../root_file.txt'? It would pass your test, even though it will traverse back directories. I would do something like the following:

import re
~skipping lots of code~
 @staticmethod
 def valid_path(path):
 """
 Will take any unix-path and check for any '.' and '..' directories
 in it. Example:
 import re
 re_path = re.compile('^\.+$')
 some_path = '/root/path/to/../folder/./file.txt'
 public_path = 'public/folder/../../../file.txt'
 next(filter(re_path.match, some_path.split('/')))
 # '..'
 '/'.join(filter(re_path.match, some_path.split('/')))
 # '..'
 valid_path = '/path/to/file.txt'
 next(filter(re_path.match, valid_path.split('/')))
 StopIteration
 """
 path = path.lstrip('.').lstrip('/')
 re_path = re.compile('^\.+$')
 try:
 match = next(filter(re_path.match, path.split('/')))
 except StopIteration:
 return True
 return False

This way you can check if there are backout paths without colliding with something inocuous like ./public/file.txt, and you can handle the following FileNotFoundError. Otherwise, you'll get a None on return, and then you can return an AccessDenied:

 def send_file(self, file_name=None):
 ~snip~
 # this is more pythonic than doing
 # if file_name == ''
 file_name = file_name or 'index.html'
 # will either be True or False
 if not self.valid_path(file_name):
 # I've added the keywords here for clarity
 self.handle_error(403, short_reason='Forbidden', reason='Forbidden')
 return 
 try:
 # use and f-string here rather than string
 # concatenation
 with open(f'./public/{file_name}', 'rb') as fh:
 file_contents = file.read()
 except FileNotFoundError:
 self.handle_error(404, short_reason='Not Found', reason='File Not Found')
 return
 file_type = self.get_type(file_name)
 if file_type is not None:
 self.set_header('Content-Type', file_type.MIME)
 # use str.endswith here
 elif file_name.endswith('html'):
 self.set_header('Content-Type', 'text/html')
 else:
 self.set_header('Content-Type', 'text/txt')
 self.response(file_contents)

Redundant Methods

The method get_request_address in my opinion doesn't need to be there. Just access self.address wherever you call it.

parse_headers

If it were up to me, I'd make the headers a dictionary from the get-go, but this can be re-factored like so:

 # this can be static since you never access any self attributes
 @staticmethod
 def parse_headers(headers):
 t = {}
 for header in headers[1:]:
 # unpack k and v from one call to split
 k, v = header.split(': ')
 t[k] = v
 return t

Or, even more succinctly

 @staticmethod
 def parse_headers(headers):
 return dict(header.split(': ') for header in headers[1:])

reply

Again, I'd force everything to bytes. You try to append to a str (message) with mixed types, which will fail:

 def reply(self):
 ''' Assembles the response and sends it to the client
 '''
 if not self.reply_headers[0].startswith("HTTP"):
 self.set_status(200, 'OK')
 self.set_header('Content-Type', 'text/html')
 self.reply_content = ['<p>Response Status unspecified</p>']
 self.calculate_content_length()
 # here's how to coerce to bytes on reply_headers
 message = b'\r\n'.join(map(str.encode, self.reply_headers))
 message += b'\r\n\r\n'
 message += b'\r\n'.join(x.encode() if isinstance(x, str) else x for x in self.reply_content))
 message += b'\r\n'
 # Now you don't have to type-check
 try:
 self.client.send(message)
 except:
 pass

Avoiding the type-checking entirely makes your code more streamlined and maintainable.

Last, instead of using pass on an unexpected Exception, I'd raise some sort of 500:

 try:
 self.client.send(message)
 except Exception as e:
 traceback.print_exception(*sys.exc_info(), file=sys.stderr)
 self.handle_error(500, 'Server Error', 'Internal Server Error')

Then at the very least there's some visibility either by the client or by you that something bad happened.

handle

This is a classic case of don't repeat yourself. Looking at the big block of if statements, your execution times will be slightly worse for a PATCH than a GET. To refactor, I would use a dictionary that binds the names of functions to operations, yielding a constant time lookup:

 def __init__(self, ...):
 ~snip~
 self.functions = {
 'GET': self.get,
 'POST': self.post,
 'HEAD': self.head,
 'PUT': self.put,
 'DELETE': self.delete,
 'CONNECT': self.connect,
 'OPTIONS': self.options,
 'TRACE': self.trace,
 'PATCH': self.patch
 }
 def call_method(self, contents, request_type=None, headers=None):
 headers = headers or {}
 headers['Content-Type'] = headers.get('Content-Type', 'text/html')
 self.set_status(200, 'OK')
 for k, v in headers.items():
 self.set_header(k, v)
 func = self.functions.get(request_type.upper() if request_type else '', self.default)
 func(headers, contents)
 self.reply()
answered Oct 7, 2019 at 17:37
\$\endgroup\$

Your Answer

Draft saved
Draft discarded

Sign up or log in

Sign up using Google
Sign up using Email and Password

Post as a guest

Required, but never shown

Post as a guest

Required, but never shown

By clicking "Post Your Answer", you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.