# Serving Machine Learning Models

## Model Serving

The process of deploying trained machine learning models to a production environment to make predictions on new data. This involves making the model accessible via an API or a GUI, allowing users or applications to send data to the model and receive predictions in return.

## Inference Strategies

Inference strategies refer to the methods and practices used to run machine learning models on new data in production. These strategies vary based on the requirements of the application, the infrastructure available, and the scale at which the model needs to operate. Common strategies include:

- **Batch Inference**: Running predictions on a batch of data at scheduled intervals. This is suitable for applications where real-time predictions are not required.
- **Online Inference**: Making predictions in real-time as data comes in. This is suitable for applications that require immediate responses, such as recommendation systems or fraud detection.
- **Streaming Inference**: Continuously processing and making predictions on a stream of data, such as sensor data or user interactions. This is suitable for applications that require real-time analysis of continuous data streams.
- **Edge Inference**: Deploying models on edge devices (e.g., IoT devices, mobile phones) to reduce latency and bandwidth usage. This is suitable for applications that require low-latency responses and can operate with limited resources.

| Feature            | **Batch Inference**                            | **Online Inference**                 | **Streaming Inference**                        | **Edge Inference**                        |
| ------------------ | ---------------------------------------------- | ------------------------------------ | ---------------------------------------------- | ----------------------------------------- |
| **Trigger**        | Scheduled (e.g., daily, hourly)                | On-demand (API call)                 | On data arrival in stream                      | Local event or sensor input               |
| **Input Type**     | Large dataset at once                          | Single or few inputs                 | Continuous data stream                         | Input from local device                   |
| **Latency**        | High (minutes to hours)                        | Low (ms–s)                           | Low to moderate (near real-time)               | Very low (ms)                             |
| **Throughput**     | Very high                                      | Low to moderate                      | High                                           | Varies (device-limited)                   |
| **Use Cases**      | Analytics, periodic predictions (e.g., churn)  | Web apps, chatbots, image APIs       | Fraud detection, log monitoring, IoT pipelines | Autonomous vehicles, drones, offline apps |
| **Deployment**     | Server, cloud job (e.g., Airflow, SageMaker)   | REST API, gRPC (FastAPI, Ray Serve)  | Stream processors (Kafka, Flink, Spark)        | Embedded device, mobile, microcontroller  |
| **Example**        | Predict next week’s demand from entire dataset | Classify one uploaded image          | Classify every transaction in real time        | Detect person using mobile camera         |
| **Model Refresh**  | Periodic (daily, weekly)                       | On model redeploy                    | Stream model updates possible                  | Often static (limited updates)            |
| **Resource Needs** | High (GPU/TPU batch servers)                   | Scalable APIs (autoscaling possible) | Stream processing engines + serving infra      | Limited CPU/RAM, no GPU                   |

| Type          | Focus              | Best For                                            |
| ------------- | ------------------ | --------------------------------------------------- |
| **Batch**     | **Scale**          | Periodic bulk jobs, forecasts, analytics            |
| **Online**    | **Responsiveness** | Real-time user apps (low-latency, request/response) |
| **Streaming** | **Reactivity**     | Always-on systems reacting to live data             |
| **Edge**      | **Locality**       | Low-latency, disconnected, on-device inference      |
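
The operational difference between these strategies can be sketched in plain Python; the model here is a hypothetical stand-in for any trained predictor, and the data is illustrative:

```python
# Hypothetical stand-in for a trained model: any callable will do.
model = lambda x: 2 * x + 1

def batch_inference(dataset):
    # Scheduled job: score an entire dataset in one pass.
    return [model(x) for x in dataset]

def online_inference(request):
    # Request/response: score a single input as it arrives.
    return model(request)

def streaming_inference(stream):
    # Always-on: yield a prediction for each event in a live stream.
    for event in stream:
        yield model(event)

print(batch_inference([1, 2, 3]))               # nightly bulk job
print(online_inference(4))                      # one API call
print(list(streaming_inference(iter([5, 6]))))  # continuous feed
```

Edge inference uses the same request/response shape as online inference; what changes is where the model runs (on the device itself), not the calling pattern.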

## Tools for Model Serving

Different tools can be used to deploy and serve machine learning models at scale, including:

- **FastAPI**: a high-performance web framework to build REST APIs for serving models.
- **Flask**: a lightweight web framework to build REST APIs for serving models.
- **TensorFlow Serving**: a specialized server for serving TensorFlow models.
- **TorchServe**: a specialized server for serving PyTorch models.
- **ONNX Runtime**: a cross-platform inference engine for ONNX models.
- **Seldon Core**: a Kubernetes-native platform to deploy and manage machine learning models.
- **BentoML**: a framework to package and deploy machine learning models as APIs.
- **LitServe**: a framework to serve ML models with minimal code and powerful features, such as batching, streaming, GPU acceleration, and autoscaling.

## LitServe

An open-source model serving framework that provides a simple and efficient way to deploy machine learning models as APIs. It is designed to be easy to use, flexible, and scalable, making it suitable for both small and large-scale deployments. It is built on top of FastAPI and provides a set of features that simplify model serving, such as:

- **Automatic Batching**: Automatically batches incoming requests to improve throughput and reduce latency.
- **Streaming Responses**: Supports streaming responses for real-time applications.
- **GPU Acceleration**: Supports GPU acceleration for faster inference.
- **Autoscaling**: Automatically scales the number of workers based on the incoming traffic.
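
To build intuition for the automatic-batching feature, here is a simplified, framework-free sketch of the underlying idea: buffer incoming requests until either a size cap or a timeout window is hit, then run one vectorized call. The names and limits are illustrative, not LitServe internals:

```python
import time

MAX_BATCH_SIZE = 4
BATCH_TIMEOUT = 0.05  # seconds

def model(batch):
    # Hypothetical vectorized model: processes all inputs in one call.
    return [x ** 2 for x in batch]

def collect_batch(incoming):
    """Buffer requests until the batch is full or the timeout window closes."""
    batch = []
    deadline = time.monotonic() + BATCH_TIMEOUT
    for request in incoming:
        batch.append(request)
        if len(batch) >= MAX_BATCH_SIZE or time.monotonic() >= deadline:
            break
    return batch

requests_queue = iter([1, 2, 3, 4, 5])
batch = collect_batch(requests_queue)
print(model(batch))  # one forward pass serves four requests
```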

## Getting Started with LitServe

- Install LitServe using pip:

    ```powershell
    pip install litserve
    ```

- Create a new directory for your project and navigate to it:

    ```powershell
    mkdir demo_litserve
    cd demo_litserve
    ```

- Create a new Python file named `server.py` and add the following code:

    ```python
    import litserve as ls


    class SimpleLitAPI(ls.LitAPI):
        def setup(self, device):
            self.model1 = lambda x: x**2  # pylint:disable=W0201
            self.model2 = lambda x: x**3  # pylint:disable=W0201

        def decode_request(self, request):  # type: ignore pylint:disable=W0221
            return request["input"]

        def predict(self, x):  # type: ignore pylint:disable=W0221
            squared = self.model1(x)
            cubed = self.model2(x)
            output = squared + cubed
            return {"output": output}

        def encode_response(self, output):  # type: ignore pylint:disable=W0221
            return {"output": output}


    if __name__ == "__main__":
        api = SimpleLitAPI()
        server = ls.LitServer(api, accelerator="auto")
        server.run(port=8000)
    ```

- Run the server on your local machine:

    ```powershell
    python server.py
    ```

- Open `http://localhost:8000/docs` to access the interactive API documentation and test the API.

- A `client.py` file will be automatically generated in the same directory to test the API. You can run it to see how the API works:

    ```powershell
    python client.py
    ```

- You can define Pydantic models for request and response validation:

    ```python
    from typing import Dict

    import litserve as ls
    from pydantic import BaseModel as PydanticBaseModel
    from fastapi.middleware.cors import CORSMiddleware

    class BaseModel(PydanticBaseModel):
        class Config:
            arbitrary_types_allowed = True

    class SimpleLitRequest(BaseModel):
        input: float

    class SimpleLitResponse(BaseModel):
        output: float

    class SimpleLitAPI(ls.LitAPI):
        def setup(self, device):
            self.model1 = lambda x: x**2  # pylint:disable=W0201
            self.model2 = lambda x: x**3  # pylint:disable=W0201

        def decode_request(self, request: SimpleLitRequest):  # type: ignore pylint:disable=W0221
            return request.input

        def predict(self, x: float):  # type: ignore pylint:disable=W0221
            squared = self.model1(x)
            cubed = self.model2(x)
            output = squared + cubed
            return {"output": output}

        def encode_response(self, output: Dict) -> SimpleLitResponse:  # type: ignore pylint:disable=W0221
            return SimpleLitResponse(output=output["output"])

    if __name__ == "__main__":
        api = SimpleLitAPI()
        cors_middleware = (
            CORSMiddleware,
            {
                "allow_origins": ["*"],  # Allows all origins
                "allow_methods": ["GET", "POST"],  # Allows GET and POST methods
                "allow_headers": ["*"],  # Allows all headers
            },
        )
        server = ls.LitServer(api, accelerator="auto", middlewares=[cors_middleware])
        server.run(port=8000)
    ```

- Test the API

    ```python
    import requests

    response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0})
    print(f"Status: {response.status_code}\nResponse:\n {response.text}")
    ```

- Enabling concurrency with `async` allows the server to handle multiple tasks seamlessly at the same time by switching between them without blocking the execution of other tasks. This is particularly useful for I/O-bound operations, such as making API calls or reading from a database, where the server can continue processing other requests while waiting for the I/O operation to complete.

    ```python
    from typing import Dict

    import litserve as ls
    from pydantic import BaseModel as PydanticBaseModel
    from fastapi.middleware.cors import CORSMiddleware

    class BaseModel(PydanticBaseModel):
        class Config:
            arbitrary_types_allowed = True

    class SimpleLitRequest(BaseModel):
        input: float

    class SimpleLitResponse(BaseModel):
        output: float

    class AsyncLitAPI(ls.LitAPI):
        def setup(self, device):
            self.model1 = lambda x: x**2  # pylint:disable=W0201
            self.model2 = lambda x: x**3  # pylint:disable=W0201

        async def decode_request(self, request: SimpleLitRequest):  # type: ignore pylint:disable=W0221
            return request.input

        async def predict(self, x: float):  # type: ignore pylint:disable=W0221
            squared = self.model1(x)
            cubed = self.model2(x)
            output = squared + cubed
            return {"output": output}

        async def encode_response(self, output: Dict) -> SimpleLitResponse:  # type: ignore pylint:disable=W0221
            return SimpleLitResponse(output=output["output"])

    if __name__ == "__main__":
        api = AsyncLitAPI(enable_async=True)
        cors_middleware = (
            CORSMiddleware,
            {
                "allow_origins": ["*"],  # Allows all origins
                "allow_methods": ["GET", "POST"],  # Allows GET and POST methods
                "allow_headers": ["*"],  # Allows all headers
            },
        )
        server = ls.LitServer(api, accelerator="auto", middlewares=[cors_middleware])
        server.run(port=8000)
    ```

- Test the API

    ```python
    import httpx
    import asyncio

    async def main():
        async with httpx.AsyncClient() as client:
            response = await client.post("http://localhost:8000/predict", json={"input": 4.0})
            print(response.json())

    asyncio.run(main())
    ```
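
- The concurrency benefit of `async` can also be seen in a standalone `asyncio` sketch, independent of LitServe: two simulated I/O waits overlap instead of running back to back, so total latency stays close to a single wait. The names and delays are arbitrary:

    ```python
    import asyncio
    import time

    async def fake_io_call(name, delay):
        # Simulates a non-blocking I/O operation (e.g., a database read).
        await asyncio.sleep(delay)
        return f"{name} done"

    async def main():
        start = time.monotonic()
        # Both waits run concurrently instead of sequentially.
        results = await asyncio.gather(
            fake_io_call("request-1", 0.1),
            fake_io_call("request-2", 0.1),
        )
        return results, time.monotonic() - start

    results, elapsed = asyncio.run(main())
    print(results, f"{elapsed:.2f}s")
    ```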

- LitServe supports streaming inference:

    ```python
    from typing import Dict

    import litserve as ls
    from pydantic import BaseModel as PydanticBaseModel
    from fastapi.middleware.cors import CORSMiddleware

    class BaseModel(PydanticBaseModel):
        class Config:
            arbitrary_types_allowed = True

    class SimpleLitRequest(BaseModel):
        input: float

    class SimpleLitResponse(BaseModel):
        output: float

    class AsyncLitAPI(ls.LitAPI):
        def setup(self, device):
            self.model1 = lambda x: x**2  # pylint:disable=W0201
            self.model2 = lambda x: x**3  # pylint:disable=W0201

        async def decode_request(self, request: SimpleLitRequest):  # type: ignore pylint:disable=W0221
            return request.input

        async def predict(self, x: float):  # type: ignore pylint:disable=W0221
            for i in range(10):
                squared = self.model1(x + i)
                cubed = self.model2(x + i)
                output = squared + cubed
                yield {"output": output}

        async def encode_response(self, output):  # type: ignore pylint:disable=W0221
            # `output` is the stream of chunks yielded by `predict`; encode each one.
            async for chunk in output:
                yield SimpleLitResponse(output=chunk["output"])

    if __name__ == "__main__":
        api = AsyncLitAPI(enable_async=True)
        cors_middleware = (
            CORSMiddleware,
            {
                "allow_origins": ["*"],  # Allows all origins
                "allow_methods": ["GET", "POST"],  # Allows GET and POST methods
                "allow_headers": ["*"],  # Allows all headers
            },
        )
        server = ls.LitServer(
            api, accelerator="auto", middlewares=[cors_middleware], stream=True
        )
        server.run(port=8000)
    ```
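
- The streaming contract above, where `predict` yields partial results one at a time instead of a single final answer, can be illustrated without a server using a plain async generator (the loop count and values are illustrative):

    ```python
    import asyncio

    async def streaming_predict(x):
        # Yields one partial result at a time, like a streaming `predict`.
        for i in range(3):
            yield {"output": (x + i) ** 2 + (x + i) ** 3}

    async def consume():
        chunks = []
        async for chunk in streaming_predict(1.0):
            chunks.append(chunk)  # a client would render each chunk on arrival
        return chunks

    print(asyncio.run(consume()))
    ```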

- LitServe allows authentication and supports FastAPI's advanced custom authentication mechanisms, such as OAuth and HTTP Bearer.

- Autoscaling can be enabled within a machine or across multiple machines.

    ```python
    import litserve as ls

    if __name__ == "__main__":
        api = ls.test_examples.SimpleLitAPI()
        # On a machine with 4 GPUs, these two lines are equivalent (use all 4):
        # server = ls.LitServer(api)
        # server = ls.LitServer(api, devices="auto")
        server = ls.LitServer(api, devices=3)  # restrict serving to 3 devices
        server.run(port=8000)
    ```

- You can also scale the API server itself by running multiple workers per device and multiple API server processes:

    ```python
    import litserve as ls

    if __name__ == "__main__":
        api = ls.test_examples.SimpleLitAPI()
        server = ls.LitServer(api, workers_per_device=2)
        # Run the server on port 8000 with 4 API servers running in separate processes
        server.run(port=8000, num_api_servers=4)
    ```

- LitServe enables request batching, combining multiple incoming requests into a single batch to improve throughput. This is particularly useful when minimizing per-request latency is less critical than maximizing overall throughput.

    ```python
    import numpy as np

    import litserve as ls

    class SimpleBatchedAPI(ls.LitAPI):
        def setup(self, device):
            self.model = lambda x: x ** 2

        def decode_request(self, request):
            return np.asarray(request["input"])

        def predict(self, x):
            result = self.model(x)
            return result

        def encode_response(self, output):
            return {"output": output}

    if __name__ == "__main__":
        api = SimpleBatchedAPI(max_batch_size=8, batch_timeout=0.05)
        server = ls.LitServer(api)
        server.run(port=8000)
    ```

- LitServe allows you to dockerize your server using the command:

    ```powershell
    litserve dockerize server.py --port 8000 --gpu
    ```

## ESA WorldCover Classification Model Deployment Using LitServe

- First, make a new directory for the project and navigate to it:

    ```powershell
    mkdir esa_worldcover_litserve
    cd esa_worldcover_litserve
    ```

- Create new Python files for the model (`model.py`), the configuration (`config.py`), and the server (`server.py`):

    `server.py`

    ```python
    from typing import Dict, List
    import torch
    import numpy as np
    import litserve as ls
    from fastapi.middleware.cors import CORSMiddleware
    from pydantic import BaseModel as PydanticBaseModel
    from model import UNet
    from config import IN_CHANNELS, OUT_CLASSES, MODEL_PATH

    class BaseModel(PydanticBaseModel):
        class Config:
            arbitrary_types_allowed = True

    class ESAWCRequest(BaseModel):
        image: List[List[List[float]]]

    class ESAWCResponse(BaseModel):
        output: List[List[int]]

    class ESAWorldCoverLitAPI(ls.LitAPI):
        def setup(self, device):
            self.model = UNet(  # pylint:disable=W0201
                in_channels=IN_CHANNELS, out_classes=OUT_CLASSES, dropout=0.0
            ).to(device)
            self.model.load_state_dict(
                torch.load(MODEL_PATH, weights_only=True, map_location="cpu")
            )
            self.model.eval()
            self.device = device

        def decode_request(self, request: ESAWCRequest):  # type: ignore pylint:disable=W0221,W0236
            image_array = np.array(request.image, dtype=np.float32)
            return image_array

        def predict(self, x: np.ndarray):  # type: ignore pylint:disable=W0221,W0236
            x = x / 10000.0
            x = torch.from_numpy(x).float()
            x = x.clip(min=0.0, max=1.0)
            x = x.to(self.device)
            with torch.no_grad():
                prediction = self.model(x.unsqueeze(0))  # Add batch dimension
            mask = torch.argmax(prediction, dim=1)
            return {"output": mask.cpu().numpy()[0, ...].tolist()}

        def encode_response(self, output: Dict) -> ESAWCResponse:  # type: ignore pylint:disable=W0221,W0236
            return ESAWCResponse(output=output["output"])  # type: ignore

    if __name__ == "__main__":
        api = ESAWorldCoverLitAPI()
        cors_middleware = (
            CORSMiddleware,
            {
                "allow_origins": ["*"],  # Allows all origins
                "allow_methods": ["GET", "POST"],  # Allows GET and POST methods
                "allow_headers": ["*"],  # Allows all headers
            },
        )
        server = ls.LitServer(
            api,
            accelerator="auto",
            middlewares=[cors_middleware],
            devices=3,
        )
        server.run(port=8000)
    ```
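
- The preprocessing and post-processing inside `predict` can be sanity-checked in isolation with NumPy: reflectance values are scaled by 10,000, clipped into `[0, 1]`, and the class mask is recovered with an argmax over the channel axis. The shapes and the two-class logits below are toy stand-ins, not the real model's:

    ```python
    import numpy as np

    def preprocess(image):
        # Scale raw reflectance by 10,000 and clip into [0, 1], as in `predict`.
        x = np.asarray(image, dtype=np.float32) / 10000.0
        return np.clip(x, 0.0, 1.0)

    def postprocess(logits):
        # Per-pixel class = index of the highest-scoring channel (argmax on axis 0).
        return np.argmax(np.asarray(logits), axis=0)

    raw = [[[0.0, 5000.0], [10000.0, 20000.0]]]  # toy image: 1 channel, 2x2 pixels
    x = preprocess(raw)                          # 20000 is clipped down to 1.0
    logits = np.stack([1 - x[0], x[0]])          # toy 2-class "prediction"
    print(postprocess(logits))
    ```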

- Test locally by running the server:

    ```powershell
    python server.py
    ```

- Dockerize the server using the command:

    ```powershell
    litserve dockerize server.py --port 8000
    ```

- Create `requirements.txt` file with the following content:

    ```plaintext
    torch==2.7.0
    litserve==0.2.10
    python-multipart
    numpy==2.2.4
    ```

- Update the `Dockerfile` to include the requirements file:

    ```dockerfile
    ARG PYTHON_VERSION=3.12

    FROM python:$PYTHON_VERSION-slim

    ####### Add your own installation commands here #######

    # RUN pip install some-package

    # RUN wget https://path/to/some/data/or/weights

    # RUN apt-get update && apt-get install -y <package-name>

    WORKDIR /app
    COPY . /app

    # Install litserve and requirements

    RUN pip install --no-cache-dir -r requirements.txt
    EXPOSE 8000
    CMD ["python", "/app/server.py"]
    ```

- Before building the image, make sure to comment out `devices=3` in the `server.py` file.
- Build the Docker image:

    ```powershell
    docker build -t eas_litserve .
    ```

- Run the Docker container:

    ```powershell
    docker run -p 8000:8000 eas_litserve
    ```

- Test the API using the `test_esa_litserve_api.ipynb` notebook.
- Tag and push the Docker image to Docker Hub:

    ```powershell
    docker tag eas_litserve albughdadim/eas_litserve:latest
    docker push albughdadim/eas_litserve:latest
    ```

- Deploy to Kubernetes using the `k8s/deployment.yml` file

```{literalinclude} esa_worldcover_litserve/k8s/deployment.yml
:language: yaml
:linenos:
```

- Apply the deployment to your Kubernetes cluster:

    ```powershell
    kubectl apply -f k8s/deployment.yml
    ```

- Expose the deployment using a service

```{literalinclude} esa_worldcover_litserve/k8s/service.yml
:language: yaml
:linenos:
```

- Apply the service to your Kubernetes cluster:

    ```powershell
    kubectl apply -f k8s/service.yml
    ```

- Expose the service using an Ingress controller

```{literalinclude} esa_worldcover_litserve/k8s/ingress.yml
:language: yaml
:linenos:
```

- Apply the ingress to your Kubernetes cluster:

    ```powershell
    kubectl apply -f k8s/ingress.yml
    ```

- Make sure to use your own `namespace` and `host` in the `k8s` yaml files.
- Test the API using the `test_esa_litserve_api.ipynb` file by changing the URL to the Ingress URL.
