Network Load Balancer from Scratch in Python
Have you ever wondered how web applications handle increasing traffic? As a software engineer, you might have heard of load balancers,
which play a crucial role in managing the distribution of requests to multiple servers.
This weekend, I decided to dive deep into socket programming and create a simple yet functional load balancer from scratch in Python. In this blog post, I’ll walk you through the process.
What is a Load Balancer anyway?
What is a load balancer anyway, you may ask. Imagine you’ve built your own OnlyFans clone, and it’s gaining traction.
To handle the increased traffic, you can’t infinitely scale your single server vertically. Instead, you buy multiple servers and host your site across them.
However, a new problem arises: how do you effectively utilize the resources of all these servers? The answer is a load balancer. It distributes incoming requests to different servers based on a selection criterion, like a simple round-robin algorithm.
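If round-robin feels abstract, here is a tiny, self-contained sketch of the idea (the server addresses are made up for illustration):

import itertools

# Made-up backend addresses, purely for illustration.
servers = ["10.0.0.1", "10.0.0.2", "10.0.0.3"]
next_server = itertools.cycle(servers)

for request_id in range(5):
    # Each request goes to the next server in the list, wrapping around at the end.
    print(f"request {request_id} -> {next(next_server)}")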
Layer 4 vs. Layer 7 Load Balancers
A layer 7 load balancer works on, you guessed it, layer 7 of the OSI model, which is the application layer. For every incoming request it does multiple things, like terminating TLS and, based on the data received, redirecting the request to a specific server.
The whole flow goes like this: client -> load balancer (decodes request) -> server. The load balancer essentially creates a new request to the server, and all the data in the request body is accessible to it. So we can say it can do things a little more smartly.
sequenceDiagram
    participant Client
    participant LoadBalancer
    participant Server
    Client->>LoadBalancer: Request R1
    LoadBalancer->>Server: Request R2
    Server->>LoadBalancer: Response R2
    LoadBalancer->>Client: Response R1
A layer 4 load balancer, on the other hand, works on (you guessed it) layer 4, which is the transport layer of the OSI model. Instead of decoding the request, it directly sends the packets to the server. It basically acts as a relay between client and server, which is faster but … dumb.
The code
Let’s dive into the code to build our Layer 4 load balancer. The first step is to listen on an address and receive data from clients. We will be using Python’s socket library for this.
import socket

# Address the load balancer listens on (values assumed for this example).
HOST, PORT = "0.0.0.0", 8080

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    sock.bind((HOST, PORT))
    sock.listen()
    print(f"Listening on port: {PORT}")
    while True:
        client_conn, client_addr = sock.accept()
        with client_conn:
            print(f"Connected by {client_addr}")
            data = client_conn.recv(1024)
            print(data.decode())
finally:
    sock.close()
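A quick sanity check: with this script running, a curl http://localhost:8080/ from another terminal (assuming the port 8080 I used above) should make the script print the raw HTTP request. curl itself will complain about an empty reply, since we never send anything back yet.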
Now that we are receiving connections from clients, we need to somehow relay requests and responses between the client and a backend server.
def forward_request(source, destination):
    # Pump bytes from one socket to the other until the source closes.
    print(f"Sending data from {source.getsockname()} to {destination.getsockname()}")
    try:
        while True:
            data = source.recv(1024)
            if len(data) == 0:
                break
            destination.sendall(data)
    finally:
        source.close()
        destination.close()
We will need two separate threads for this: one for receiving the client’s request and sending it to the server, and one for receiving the server’s response and sending it to the client.
c2b_thread = threading.Thread(target=forward_request, args=(client_conn, backend_conn))
b2c_thread = threading.Thread(target=forward_request, args=(backend_conn, client_conn))
c2b_thread.start()
b2c_thread.start()
c2b_thread.join()
b2c_thread.join()
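For context, here is roughly how these threads can be wired into the accept loop from earlier. The backend list, the round-robin choice via itertools.cycle, and the handle_client helper are my own assumptions for illustration, not part of the original snippets:

import itertools
import socket
import threading

# Assumed backend addresses; in reality these would come from configuration.
BACKENDS = [("127.0.0.1", 9001), ("127.0.0.1", 9002)]
backend_pool = itertools.cycle(BACKENDS)

def handle_client(client_conn):
    # Pick the next backend round-robin and open a TCP connection to it.
    backend_conn = socket.create_connection(next(backend_pool))

    # One thread per direction: client -> backend and backend -> client.
    c2b_thread = threading.Thread(target=forward_request, args=(client_conn, backend_conn))
    b2c_thread = threading.Thread(target=forward_request, args=(backend_conn, client_conn))
    c2b_thread.start()
    b2c_thread.start()
    c2b_thread.join()
    b2c_thread.join()

Each accepted client_conn would then be handed to handle_client, ideally on its own thread so one slow client doesn’t block the others.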
Now comes the tricky part: what if one server dies? It would be foolish to relay requests to a dead server.
We need a heartbeat mechanism to periodically check whether each server is up and running. If it isn’t, we should exclude it from our pool of servers. We must perform this check without blocking the load balancer’s main loop, so we’ll create a separate thread for it.
import threading
from time import sleep

import requests

def get_server_heart_beat(server):
    # Hit the backend's /health endpoint; any connection error counts as "down".
    try:
        resp = requests.get(f"http://{server.host}:{server.port}/health", timeout=2)
        return resp.text == "up"
    except requests.RequestException:
        return False

def update_heartbeat(server, delay):
    # Poll a single backend forever, updating its health flag every `delay` seconds.
    while True:
        server_heart_beat = get_server_heart_beat(server)
        server.update_health_status(server_heart_beat)
        sleep(delay)

def check_health(servers):
    # Spawn one polling thread per backend so the checks never block the main loop.
    threads = []
    for s in servers:
        t = threading.Thread(target=update_heartbeat, args=(s, 2))
        threads.append(t)
        t.start()
    return threads
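The server objects passed around above are never defined in these snippets; here is a minimal sketch of what they might look like, assuming just a host, a port, and a thread-safe health flag (the attribute and method names are my own):

import threading

class Server:
    # Minimal backend descriptor used by the health-check threads (illustrative only).
    def __init__(self, host, port):
        self.host = host
        self.port = port
        self._healthy = False
        self._lock = threading.Lock()

    def update_health_status(self, healthy):
        # Called from the heartbeat thread.
        with self._lock:
            self._healthy = healthy

    def is_healthy(self):
        # The load balancer would skip servers where this returns False.
        with self._lock:
            return self._healthy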
We are almost done; we just need to test our load balancer.
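To test it end to end, each backend needs to expose the /health endpoint the heartbeat code expects. One way to spin up a throwaway backend with nothing but the standard library (the port and the response bodies are assumptions that simply match the health check above):

from http.server import BaseHTTPRequestHandler, HTTPServer

class BackendHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        # Answer "up" on /health so the load balancer's heartbeat passes;
        # anything else just identifies which backend served the request.
        body = b"up" if self.path == "/health" else b"hello from backend 9001"
        self.send_response(200)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

if __name__ == "__main__":
    HTTPServer(("127.0.0.1", 9001), BackendHandler).serve_forever()

Run a couple of these on different ports, point the load balancer at them, and requests sent to the load balancer should alternate between the backends.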
The full code can be found on my GitHub.