
Network Load Balancer from Scratch in Python

Have you ever wondered how web applications handle increasing traffic? As a software engineer, you might have heard of load balancers,
which play a crucial role in distributing requests across multiple servers.

This weekend, I decided to dive deep into socket programming and build a simple yet functional load balancer from scratch in Python. In this blog post, I’ll walk you through the process.

What is a Load Balancer Anyway?

Imagine you’ve built your own OnlyFans clone, and it’s gaining traction.

To handle the increased traffic, you can’t infinitely scale your single server vertically. Instead, you buy multiple servers and host your site across them.

However, a new problem arises: how do you effectively utilize the resources of all these servers? The answer is a load balancer. It distributes incoming requests across the servers based on a selection criterion, such as a simple round-robin algorithm.
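To make that concrete, here is a minimal sketch of round-robin selection. The `backends` list and the `pick_backend` helper are illustrative names I’m using here, not part of the final code:

    from itertools import cycle

    # Hypothetical pool of backend (host, port) pairs, just for illustration
    backends = [("10.0.0.1", 8080), ("10.0.0.2", 8080), ("10.0.0.3", 8080)]
    rotation = cycle(backends)

    def pick_backend():
        # Each call returns the next server in the rotation, wrapping around forever
        return next(rotation)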

Layer 4 vs. Layer 7 Load Balancers

A Layer 4 load balancer operates at the transport layer: it forwards raw TCP (or UDP) bytes to a backend without looking inside them. A Layer 7 load balancer operates at the application layer: it parses the protocol (typically HTTP) and can route based on paths, headers, or cookies. We will build the Layer 4 kind, which sits between the client and a backend server as a simple relay:

sequenceDiagram
    participant Client
    participant LoadBalancer
    participant Server
    Client->>LoadBalancer: Request R1
    LoadBalancer->>Server: Request R2
    Server->>LoadBalancer: Response R2
    LoadBalancer->>Client: Response R1

The Code

Let’s dive into the code to build our Layer 4 load balancer. The first step is to listen on an address and receive data from clients. We will use Python’s socket library for this.

    import socket

    HOST, PORT = "127.0.0.1", 8000  # example values; use your own address and port

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind((HOST, PORT))
        sock.listen()
        print(f"Listening on port: {PORT}")
        while True:
            # Block until a client connects, then read its first chunk of data
            client_conn, client_addr = sock.accept()
            with client_conn:
                print(f"Connected by {client_addr}")
                data = client_conn.recv(1024)
                print(data.decode())
    finally:
        sock.close()

Now that we are receiving connections from clients, we need a way to relay requests and responses between the client and a backend server.

    def forward_request(source, destination):
        print(f"Sending data from {source.getsockname()} to {destination.getsockname()}")
        try:
            while True:
                data = source.recv(1024)
                if len(data) == 0:
                    # A zero-byte read means the peer closed its end of the connection
                    break
                # sendall retries until every byte is written, unlike send
                destination.sendall(data)
        finally:
            source.close()
            destination.close()

We need two separate threads for this: one that reads the client’s request and sends it to the server, and one that reads the server’s response and sends it back to the client.

    c2b_thread = threading.Thread(target=forward_request, args=(client_conn, backend_conn))
    b2c_thread = threading.Thread(target=forward_request, args=(backend_conn, client_conn))
    c2b_thread.start()
    b2c_thread.start()
    c2b_thread.join()
    b2c_thread.join()
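For completeness, here is one way these pieces might fit into the accept loop. `pick_backend` is the illustrative round-robin selector from earlier, and `handle_client` is a hypothetical helper I’m introducing here, not code from the post:

    import threading

    def handle_client(client_conn):
        # Open a fresh connection to the chosen backend for this client
        backend_host, backend_port = pick_backend()
        backend_conn = socket.create_connection((backend_host, backend_port))
        c2b_thread = threading.Thread(target=forward_request, args=(client_conn, backend_conn))
        b2c_thread = threading.Thread(target=forward_request, args=(backend_conn, client_conn))
        c2b_thread.start()
        b2c_thread.start()
        c2b_thread.join()
        b2c_thread.join()

    while True:
        client_conn, client_addr = sock.accept()
        # Handle each client on its own thread so accept() never blocks on a slow one
        threading.Thread(target=handle_client, args=(client_conn,)).start()

Running each client on its own thread keeps the accept loop responsive while the pair of forwarding threads shuttles bytes in both directions.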

Now comes the tricky part: what if one of the servers dies? It would be foolish to relay requests to a dead server.
We need a heartbeat mechanism that periodically checks whether each server is up and running; if it isn’t, we exclude it from the pool. This check must not block the load balancer’s main loop, so we run it on separate threads.

    import threading
    from time import sleep

    import requests

    def get_server_heart_beat(server):
        try:
            # Each backend exposes a /health endpoint that returns "up" when healthy
            resp = requests.get(f"http://{server.host}:{server.port}/health")
            return resp.text == "up"
        except requests.RequestException:
            # Connection refused, timeout, etc. all count as a failed heartbeat
            return False

    def update_heartbeat(server, delay):
        while True:
            server_heart_beat = get_server_heart_beat(server)
            server.update_health_status(server_heart_beat)
            sleep(delay)

    def check_health(servers):
        # One monitoring thread per server, polling every 2 seconds
        threads = []
        for s in servers:
            t = threading.Thread(target=update_heartbeat, args=(s, 2))
            threads.append(t)
            t.start()
        return threads
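The heartbeat code assumes a `Server` object with `host`, `port`, and an `update_health_status` method, which the snippets above don’t show. Here is a minimal sketch of what such a class might look like; the lock is there only so the health flag can be updated safely from the monitoring threads:

    import threading

    class Server:
        def __init__(self, host, port):
            self.host = host
            self.port = port
            self.healthy = True
            self._lock = threading.Lock()

        def update_health_status(self, healthy):
            # Called from the monitoring thread on every heartbeat
            with self._lock:
                self.healthy = healthy

        def is_healthy(self):
            # The selection logic can use this to skip dead servers
            with self._lock:
                return self.healthy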

We are almost done; we just need to test our load balancer.
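One quick way to test it, assuming the balancer listens on 127.0.0.1:8000 and the backends are plain HTTP servers (for example, `python -m http.server`), is to send a raw request through it:

    import socket

    with socket.create_connection(("127.0.0.1", 8000)) as conn:
        # Connection: close makes the backend hang up, which ends our recv loop
        conn.sendall(b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
        chunks = []
        while True:
            data = conn.recv(1024)
            if not data:
                break
            chunks.append(data)

    print(b"".join(chunks).decode(errors="replace"))

Run it a few times: with round-robin selection, successive requests should land on different backends.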

The full code can be found on my GitHub.

Tags: python, networking, codingchallenges