It's an extremely simply HTTP server in Python using the socket library, and a few others to get the MIME type, etc...
I've also avoided the ../../
vulnerability, although some of the code in the send_file
function seems a bit weak.
It should also be PEP8 compliant, aside from maybe some trailing whitespace in comments.
import filetype
import socket
import _thread
class ServerSocket():
''' Recieves connections, parses the request and ships it off to a handler
'''
def __init__(self, address, handler=None, *args, **kwargs):
''' Creates a server socket and defines a handler for the server
'''
self.socket = socket.socket()
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(address)
self.handler_args = args
self.handler_kwargs = kwargs
if handler:
self.handler = handler # The custom handler
else:
self.handler = Handler # The default handler
def initialise(self, open_connections=5):
''' Initilises the server socket and has it listen for connections
'''
self.socket.listen(open_connections)
self.listen()
def parse(self, data):
''' Splits a packet into
the request,
the headers (which includes the request),
and contents
'''
stringed = str(data, 'utf-8')
split = stringed.split('\r\n\r\n')
headers = split[0]
if len(split) > 1:
content = split[1]
else:
content = []
request = headers.split(' ')[0]
return request, headers, content
def handle(self, client, address):
''' Parses the data and handles the request. It then closes the connection
'''
try:
data = client.recv(1024)
except ConnectionResetError:
if self.handler_kwargs["logging"] is True:
print(f'{address[0]} unexpectedly quit')
client.close()
return
parsed = self.parse(data)
handler = self.handler(self.handler_args, self.handler_kwargs)
handler.handle(client, parsed, address)
client.close()
def listen(self):
''' Listens until a keyboard interrupt and handles each connection in a
new thread
'''
try:
while True:
client_data = self.socket.accept()
if self.handler_kwargs['logging'] is True:
print(f'Connection from {client_data[1][0]}')
_thread.start_new_thread(self.handle, client_data)
except KeyboardInterrupt:
self.socket.close()
class Handler():
''' Handles requests from the Server Socket
'''
def __init__(self, args, kwargs):
self.args = args
self.kwargs = kwargs
def set_status(self, code, message):
''' Used to add a status line:
- 'HTTP/1.1 200 OK' or 'HTTP/1.1 404 Not Found', etc...
'''
self.reply_headers = [f'HTTP/1.0 {code} {message}']
def set_header(self, header, content):
''' Defines a custom header and adds it to the response
'''
self.reply_headers += [f'{header}: {content}']
def response(self, content):
''' Adds to the content of the response
'''
if type(content) == str:
self.reply_content += content.split('\n')
else:
self.reply_content += [content]
def calculate_content_length(self):
''' Calculates the content length and adds it to the header
'''
length = len(self.reply_content) * 2
lengths = [len(line) for line in self.reply_content]
length += sum(lengths)
self.set_header('Content-Length', length)
def get_type(self, file_name):
return filetype.guess('./public/'+file_name)
def extract_file_name(self, file_name=None):
if file_name:
f_name = file_name[1:]
else:
f_name = self.request_status.split(' ')[1][1:]
return f_name
def send_file(self, file_name=None):
if file_name is None:
file_name = self.extract_file_name()
if file_name == '':
file_name = 'index.html'
elif file_name[0] in './':
self.set_status(403, "Forbidden")
self.set_header('Content-Type', 'text/html')
self.reply_content = ['<p>Error 403: Forbidden</p>']
return
try:
with open('./public/'+file_name, 'rb') as file:
file_contents = file.read()
except FileNotFoundError:
self.set_status(404, 'Not Found')
self.set_header('Content-Type', 'text/html')
self.reply_content = ['<p>Error 404: File not found</p>']
return
file_type = self.get_type(file_name)
if file_type is not None:
self.set_header('Content-Type', file_type.MIME)
elif file_name.split('.')[-1] == 'html':
self.set_header('Content-Type', 'text/html')
else:
self.set_header('Content-Type', 'text/txt')
self.response(file_contents)
def get_request_address(self):
return self.address
def parse_headers(self, headers):
t = {}
for header in headers[1:]:
t[header.split(': ')[0]] = header.split(': ')[1]
return t
def reply(self):
''' Assembles the response and sends it to the client
'''
if self.reply_headers[0][0:4] != "HTTP":
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.reply_content = ['<p>Response Status unspecified</p>']
self.calculate_content_length()
message = '\r\n'.join(self.reply_headers)
message += '\r\n\r\n'
try:
message += '\r\n'.join(self.reply_content)
message += '\r\n'
except TypeError:
message = bytes(message, 'utf-8')
message += b'\r\n'.join(self.reply_content)
message += b'\r\n'
try:
if type(message) == str:
self.client.send(bytes(message, 'utf-8'))
else:
self.client.send(message)
except:
pass
def handle(self, client, parsed_data, address):
''' Initialises variables and case-switches the request type to
determine the handler function
'''
self.client = client
self.address = address
self.reply_headers = []
self.reply_content = []
self.headers = True
self.request_status = parsed_data[1].split('\r\n')[0]
request = parsed_data[0]
headers = self.parse_headers(parsed_data[1].split('\r\n'))
contents = parsed_data[2]
if request == "GET":
func = self.get
elif request == "POST":
func = self.post
elif request == "HEAD":
func = self.head
elif request == "PUT":
func = self.put
elif request == "DELETE":
func = self.delete
elif request == "CONNECT":
func = self.connect
elif request == "OPTIONS":
func = self.options
elif request == "TRACE":
func = self.trace
elif request == "PATCH":
func = self.patch
else:
func = self.default
func(headers, contents)
self.reply()
def default(self, headers, contents):
''' If the request is not known, defaults to this
'''
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.response('''<p>Unknown Request Type</p>''')
def get(self, headers, contents):
''' Overwrite to customly handle GET requests
'''
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.response('''<p>Successfully got a GET Request</p>''')
def post(self, headers, contents):
''' Overwrite to customly handle POST requests
'''
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.response('''<p>Successfully got a POST Request</p>''')
def head(self, headers, contents):
''' Overwrite to customly handle HEAD requests
'''
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.response('''<p>Successfully got a HEAD Request</p>''')
def put(self, headers, contents):
''' Overwrite to customly handle PUT requests
'''
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.response('''<p>Successfully got a PUT Request</p>''')
def delete(self, headers, contents):
''' Overwrite to customly handle DELETE requests
'''
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.response('''<p>Successfully got a DELETE Request</p>''')
def connect(self, headers, contents):
''' Overwrite to customly handle CONNECT requests
'''
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.response('''<p>Successfully got a CONNECT Request</p>''')
def options(self, headers, contents):
''' Overwrite to customly handle OPTIONS requests
'''
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.response('''<p>Successfully got an OPTIONS Request</p>''')
def trace(self, headers, contents):
''' Overwrite to customly handle TRACE requests
'''
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.response('''<p>Successfully got a TRACE Request</p>''')
def patch(self, headers, contents):
''' Overwrite to customly handle PATCH requests
'''
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.response('''<p>Successfully got a PATCH Request</p>''')
# =======================================================================
if __name__ == "__main__":
import sys
class CustomHandler(Handler):
def get(self, headers, contents):
self.set_status(200, 'OK')
request_address = self.get_request_address()[0]
file_name = self.extract_file_name()
print(f'{request_address} -> {headers["Host"]}/{file_name}')
self.send_file()
def run():
if len(sys.argv) == 2:
port = int(sys.argv[1])
else:
port = 80
try:
print('Initialising...', end='')
http_server = ServerSocket(
('0.0.0.0', port),
CustomHandler,
logging=True
)
print('Done')
http_server.initialise()
except Exception as e:
print(f'{e}')
run()
1 Answer 1
ServerSocket
self.handler
The handler evaluation in __init__
can be accomplished with an or
ternary operation. It's clearer to the reader as to what's going on. Also, the name could be changed to HandlerClass
, since it represents a class
rather than an instance:
self.HandlerClass = handler or Handler
self.handler
logging
The check against handler_kwargs
makes the code a bit difficult to follow, since the SocketServer
is now in charge of something that arguably the handler
should be doing. If the server is who should be doing the logging, then I would leave the logging check out of the handler_kwargs
entirely. Store self.logging
as a boolean setting on the server instance, and just check against that:
# for traceback handling
import traceback
class SocketServer:
def __init__(self, address, handler=None, *args, **kwargs):
~skipping some code~
# this is so you don't unpack logging from kwargs since
# it looks like you need it in your handler
self.logging = kwargs.get('logging', False)
...
def handle(self):
try:
data = client.recv(1024)
except ConnectionResetError as e:
# use the more pythonic `if bool` check
# here, rather than comparing against a singleton
if self.logging:
print(f'{address[0]} unexpectedly quit: {e}', file=sys.stderr)
traceback.print_exception(*sys.exc_info(), file=sys.stderr)
I've added a sys.stderr
stream to your print statement, and added the exception to your print statement. I've also added a traceback print which points to stderr.
handle
Refactor
The client.close(); return
statement can also be refactored using try/except
's else
feature and a finally
block, since you always want the client to close
try:
data = client.recv(1024)
except ConnectionResetError as e:
if self.logging:
print(f'{address[0]} unexpectedly quit: {e}', file=sys.stderr)
traceback.print_exception(*sys.exc_info(), file=sys.stderr)
# this is if no exception fired
else:
parsed = self.parse(data)
handler = self.HandlerClass()
handler.handle(client, parsed, address)
# this will always execute
finally:
client.close()
Convert bytes
to str
parse
First, when converting bytes to str, use bytes.decode()
rather than str(byte_obj, encoding)
. The headers
and content
evaluation can be handled using argument unpacking, and from there you can use an or
expression with any
on contents
to either take the first result or create an empty list:
def parse(self, data):
''' Splits a packet into
the request,
the headers (which includes the request),
and contents
'''
stringed = data.decode().split('\r\n\r\n')
# The headers and content blocks can be handled also with argument
# unpacking and a ternary operator on content:
headers, *content = data.decode().split('\r\n\r\n')
content = content[0] if content else []
request = headers.split(' ')[0]
return request, headers, content
Handler
Adding to lists
There are lots of cases where you do this:
self.some_list += [some_value]
Just use .append(some_value)
, it's faster since you don't have to create a list just to add the value to an existing list. For example:
def set_header(self, header, content):
''' Defines a custom header and adds it to the response
'''
# this is much quicker, do this instead
self.reply_headers.append(f'{header}: {content}')
type-checking
Use isinstance
rather than type(object) == some_type
:
# This
def response(self, content):
''' Adds to the content of the response
'''
if type(content) == str:
self.reply_content += content.split('\n')
else:
self.reply_content += [content]
# should be this
def response(self, content):
if isinstance(content, str):
self.reply_content.extend(content.split('\n'))
else:
self.reply_content.append(content)
Note that I'm also switching the addition of lists to appropriate calls to append
and extend
. Though, looking through the rest of your code, you only use this against str
or bytes
types, so I'd refactor to the following:
def response(self, content):
# force everything to bytes-type for convenience, that way
# you never have to worry about TypeErrors later
content = content if isinstance(content, bytes) else content.encode()
self.reply_content.append(content)
content-length calculation
Here, you can drop the creation of the lengths
list, and just use sum
on map
:
def calculate_content_length(self):
''' Calculates the content length and adds it to the header
'''
length = len(self.reply_content) * 2
# sum will take any iterable, even generators and maps
# len in map is the function to be applied to each element
length += sum(map(len, self.reply_content))
self.set_header('Content-Length', length)
It's faster and doesn't build as many objects inside the function
Magic Numbers
In your extract_file_name
function, you are slicing the file name from the first element, though it's not completely clear why:
def extract_file_name(self, file_name=None):
if file_name:
f_name = file_name[1:]
else:
f_name = self.request_status.split(' ')[1][1:]
return f_name
This is usually a code smell and you should include a docstring and/or comments to explain why you slice that way. Otherwise, the index is a "magic number" and can be difficult for you or others to maintain later. You also never use the file_name
argument at any point when calling this function, so I might just leave it out.
Sending Files
Most of the improvements here follow the ones that have been suggested above:
Ternary Operation or Use or
for file_name
file_name = file_name or self.extract_file_name()
## or
file_name = file_name if file_name else self.extract_file_name()
The latter mirrors your if
statement a bit more
Checking an empty string
Use if some_string
not if some_string == ''
.
Checking string start and end values
The str.startswith
and str.endswith
methods will help here, and they avoid indexing or slicing which can impair readability:
if file_name.startswith('./'):
# do something
if file_name.endswith('html'):
# do something
The logic in send_file
could be cleaned up a bit. First, I think a handle_error
method would do nicely to clean up some of the repeated code where you handle exceptions:
def handle_error(self, code, short_reason, reason):
"""
code: integer error code
short_reason: string denoting the short error reason
reason: string denoting the full error reason
example:
self.handle_error(404, 'Not Found', 'File Not Found')
"""
self.set_status(code, short_reason)
self.set_header('Content-Type', 'text/html')
self.reply_content = [f'<p>Error {code}: {reason} </p>']
Next, I think the filename checking can be refactored to make a bit more sense. First, it seems a bit counterintuitive that a path that might start with '.'
is forbidden. What about ./public/index.html
? If you're trying to avoid folder traversal such as paths with .
and ..
in them, it might not be that bad to use a regex. For example, what if I tried to give you a path like `../../root_file.txt'? It would pass your test, even though it will traverse back directories. I would do something like the following:
import re
~skipping lots of code~
@staticmethod
def valid_path(path):
"""
Will take any unix-path and check for any '.' and '..' directories
in it. Example:
import re
re_path = re.compile('^\.+$')
some_path = '/root/path/to/../folder/./file.txt'
public_path = 'public/folder/../../../file.txt'
next(filter(re_path.match, some_path.split('/')))
# '..'
'/'.join(filter(re_path.match, some_path.split('/')))
# '..'
valid_path = '/path/to/file.txt'
next(filter(re_path.match, valid_path.split('/')))
StopIteration
"""
path = path.lstrip('.').lstrip('/')
re_path = re.compile('^\.+$')
try:
match = next(filter(re_path.match, path.split('/')))
except StopIteration:
return True
return False
This way you can check if there are backout paths without colliding with something inocuous like ./public/file.txt
, and you can handle the following FileNotFoundError
. Otherwise, you'll get a None
on return, and then you can return an AccessDenied
:
def send_file(self, file_name=None):
~snip~
# this is more pythonic than doing
# if file_name == ''
file_name = file_name or 'index.html'
# will either be True or False
if not self.valid_path(file_name):
# I've added the keywords here for clarity
self.handle_error(403, short_reason='Forbidden', reason='Forbidden')
return
try:
# use and f-string here rather than string
# concatenation
with open(f'./public/{file_name}', 'rb') as fh:
file_contents = file.read()
except FileNotFoundError:
self.handle_error(404, short_reason='Not Found', reason='File Not Found')
return
file_type = self.get_type(file_name)
if file_type is not None:
self.set_header('Content-Type', file_type.MIME)
# use str.endswith here
elif file_name.endswith('html'):
self.set_header('Content-Type', 'text/html')
else:
self.set_header('Content-Type', 'text/txt')
self.response(file_contents)
Redundant Methods
The method get_request_address
in my opinion doesn't need to be there. Just access self.address
wherever you call it.
parse_headers
If it were up to me, I'd make the headers a dictionary from the get-go, but this can be re-factored like so:
# this can be static since you never access any self attributes
@staticmethod
def parse_headers(headers):
t = {}
for header in headers[1:]:
# unpack k and v from one call to split
k, v = header.split(': ')
t[k] = v
return t
Or, even more succinctly
@staticmethod
def parse_headers(headers):
return dict(header.split(': ') for header in headers[1:])
reply
Again, I'd force everything to bytes
. You try to append to a str
(message
) with mixed types, which will fail:
def reply(self):
''' Assembles the response and sends it to the client
'''
if not self.reply_headers[0].startswith("HTTP"):
self.set_status(200, 'OK')
self.set_header('Content-Type', 'text/html')
self.reply_content = ['<p>Response Status unspecified</p>']
self.calculate_content_length()
# here's how to coerce to bytes on reply_headers
message = b'\r\n'.join(map(str.encode, self.reply_headers))
message += b'\r\n\r\n'
message += b'\r\n'.join(x.encode() if isinstance(x, str) else x for x in self.reply_content))
message += b'\r\n'
# Now you don't have to type-check
try:
self.client.send(message)
except:
pass
Avoiding the type-checking entirely makes your code more streamlined and maintainable.
Last, instead of using pass
on an unexpected Exception, I'd raise some sort of 500
:
try:
self.client.send(message)
except Exception as e:
traceback.print_exception(*sys.exc_info(), file=sys.stderr)
self.handle_error(500, 'Server Error', 'Internal Server Error')
Then at the very least there's some visibility either by the client or by you that something bad happened.
handle
This is a classic case of don't repeat yourself. Looking at the big block of if
statements, your execution times will be slightly worse for a PATCH
than a GET
. To refactor, I would use a dictionary that binds the names of functions to operations, yielding a constant time lookup:
def __init__(self, ...):
~snip~
self.functions = {
'GET': self.get,
'POST': self.post,
'HEAD': self.head,
'PUT': self.put,
'DELETE': self.delete,
'CONNECT': self.connect,
'OPTIONS': self.options,
'TRACE': self.trace,
'PATCH': self.patch
}
def call_method(self, contents, request_type=None, headers=None):
headers = headers or {}
headers['Content-Type'] = headers.get('Content-Type', 'text/html')
self.set_status(200, 'OK')
for k, v in headers.items():
self.set_header(k, v)
func = self.functions.get(request_type.upper() if request_type else '', self.default)
func(headers, contents)
self.reply()