Build an Inference Subnet

Welcome to the Hypertensor protocol quickstart guide! In this walkthrough, we’ll create a simple inference protocol from start to finish. Whether you’re a seasoned developer or just starting out, this guide has got you covered.

What You’ll Achieve

By the end of this quickstart, you’ll have built a protocol that can:

  • Load a model from Hugging Face

  • Register RPC methods so other nodes can call inference on your node

  • Call inference on other nodes

  • Be deployed in a distributed intelligence application

Set Up Your Development Environment

  1. Clone the subnet template:

git clone https://github.com/hypertensor-blockchain/subnet-template.git

  2. Find the /protocols directory at subnet/app/protocols and create two new files: inference_protocol.py and inference_model.py.

  3. In inference_model.py, copy and paste the following code:

import asyncio
import threading
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional, Union

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    TextIteratorStreamer,
)
from transformers import logging as transformers_logging

from subnet.utils.logging import get_logger

transformers_logging.set_verbosity_info()

logger = get_logger(__name__)

"""
Used internally for validation
"""
def set_seed(seed: int = 42):
    import random, numpy as np  # noqa: E401, I001
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.use_deterministic_algorithms(True, warn_only=True)

class InferenceModel:
    def __init__(
        self,
        model_name_or_path: str,
        device: Optional[Union[str, torch.device]] = None,
    ):
        logger.info(f"self.device {self.device}")
        self.model_name_or_path = model_name_or_path
        logger.info(f"Loading {self.model_name_or_path} tokenizer...")
        self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
        self.model.eval()

    async def stream_infer(
        self,
        inputs: torch.Tensor,
        max_new_tokens: int = 5,
        do_sample: bool = True
    ) -> AsyncIterator[torch.Tensor]:
        """
        Stream one token at a time as they're generated by the model in real time.
        """
        print("stream_infer.")
        logger.info("stream_infer.")

        queue: asyncio.Queue[Optional[torch.Tensor]] = asyncio.Queue()
        loop = asyncio.get_event_loop()

        # Move inputs to device
        inputs = inputs.to(self.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.long)

        # Setup TextStreamer
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=False,
        )

        # Generation thread
        def generate():
            try:
                with torch.no_grad():
                    self.model.generate(
                        input_ids=inputs,
                        attention_mask=attention_mask,
                        max_new_tokens=max_new_tokens,
                        eos_token_id=self.tokenizer.eos_token_id,
                        pad_token_id=self.tokenizer.eos_token_id,
                        do_sample=do_sample,
                        use_cache=True,
                        streamer=streamer,
                    )
            except Exception as e:
                logger.exception(f"Exception during model.generate: {e}", exc_info=True)
                loop.call_soon_threadsafe(queue.put_nowait, None)

        # Streaming thread
        def stream_worker():
            try:
                for decoded_token in streamer:
                    token_ids = self.tokenizer.encode(decoded_token, add_special_tokens=False)
                    for tid in token_ids:
                        print("tid", tid)
                        loop.call_soon_threadsafe(queue.put_nowait, torch.tensor([tid]))
            except Exception as e:
                logger.exception(f"Exception in stream_worker: {e}", exc_info=True)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, None)

        # Start both threads
        threading.Thread(target=generate, daemon=True).start()
        threading.Thread(target=stream_worker, daemon=True).start()

        # Yield tokens as they arrive
        while True:
            token = await queue.get()
            if token is None:
                break
            logger.debug(f"Yielding token: {token}")
            yield token

@dataclass
class InferenceTask:
    tensors: torch.Tensor
    future: asyncio.Future

@dataclass(order=True)
class PrioritizedTask:
    priority: int
    count: int  # Tie-breaker to ensure FIFO order for same priority
    task: InferenceTask = field(compare=False)

class AsyncInferenceServer:
    def __init__(self, model: InferenceModel):
        self.model = model
        self.queue: asyncio.PriorityQueue[PrioritizedTask] = asyncio.PriorityQueue()
        self._counter = 0  # For FIFO ordering of equal-priority tasks

    async def start(self):
        # Keep a reference to the worker task so it isn't garbage collected
        self._worker_task = asyncio.create_task(self._worker())

    async def submit(self, tensors: torch.Tensor, priority: int = 10) -> AsyncIterator[torch.Tensor]:
        """
        Submit an inference task to the queue with an optional priority level.

        This method packages a tensor input into an `InferenceTask` and places it
        into an internal priority queue for asynchronous processing. Tasks with
        lower priority values (e.g., 0) are processed before those with higher
        values (e.g., 10). This allows important tasks (e.g., internal validator
        jobs) to cut ahead of less urgent ones.

        Args:
            tensors (torch.Tensor): The input tensor to be processed by the model.
            priority (int, optional): A numeric priority value for the task. Lower
                values mean higher priority. Defaults to 10.

        Returns:
            AsyncIterator[torch.Tensor]: An asynchronous iterator that yields
            inference outputs from the model.

        Example:
            stream = await model.submit(my_input_tensor, priority=5)
            async for token in stream:
                print(token)
        """
        logger.debug(f"Task submitted to async inference server, priority={priority}")
        print(f"Task submitted to async inference server, priority={priority}")

        future = asyncio.get_event_loop().create_future()
        task = InferenceTask(tensors=tensors, future=future)

        self._counter += 1
        prioritized_task = PrioritizedTask(priority=priority, count=self._counter, task=task)
        await self.queue.put(prioritized_task)

        return await future

    async def _worker(self):
        logger.debug("Priority queue worker started...")
        while True:
            try:
                prioritized_task: PrioritizedTask = await self.queue.get()
                task = prioritized_task.task
                logger.debug(f"Processing task with priority {prioritized_task.priority}")

                try:
                    stream = self.model.stream_infer(task.tensors)
                    task.future.set_result(stream)
                except Exception as e:
                    logger.error(f"Inference failed: {e}", exc_info=True)
                    task.future.set_exception(e)
                finally:
                    self.queue.task_done()
            except Exception as e:
                logger.error(f"Worker error: {e}", exc_info=True)

This will load the model from Hugging Face and allow other nodes to call inference on you, with the results returned as a stream.

  4. In inference_protocol.py, start by creating the InferenceProtocol class:
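
The exact base class and constructor signature come from the subnet template (see MockProtocol in subnet/app/protocols for the real one). The skeleton below is only an illustrative sketch: dht is a placeholder for whatever handle the template's Server supplies, and the imports also cover the methods added in the next steps.

import asyncio
from typing import AsyncIterator, Optional

import torch

from subnet.utils.logging import get_logger

from .inference_model import AsyncInferenceServer, InferenceModel

logger = get_logger(__name__)


class InferenceProtocol:  # in the template this subclasses the provided servicer/protocol base class
    def __init__(
        self,
        dht,                              # DHT handle supplied by the template's Server (placeholder)
        model_name_or_path: str,          # Hugging Face model to serve, e.g. "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        device: Optional[str] = None,
    ):
        self.dht = dht
        self.model_name_or_path = model_name_or_path
        self.device = device
        self.model: Optional[InferenceModel] = None
        self.server: Optional[AsyncInferenceServer] = None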


Let's add a way to start the protocol and load the model. Add:
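
A minimal sketch of what this can look like, assuming the template runs each protocol in its own child process with its own event loop; the RPC-registration step is left as a placeholder for whatever hook the template's servicer base actually provides.

    def run(self):
        """Load the model and register the RPC methods in the protocol's child process."""
        self.model = InferenceModel(self.model_name_or_path, device=self.device)
        self.server = AsyncInferenceServer(self.model)

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.server.start())  # start the priority-queue worker
        # Placeholder: register the rpc_* methods with the template's DHT/P2P layer here
        loop.run_forever()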

This starts the protocol in its child process and registers the RPC methods we will add next.


We now need a way to gather the RPC methods other peers have registered, so you can call inference on other nodes just as they can call inference on yours. Add:
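
In the template this is usually a stub helper provided by the servicer base class; the sketch below only shows the shape of it, with the body left as a placeholder.

    def get_stub(self, peer_id):
        """Return a stub whose methods call the rpc_* handlers registered by a remote peer."""
        # Placeholder: in the template, the servicer base class builds this stub
        # from the peer's RPC methods advertised on the DHT/P2P layer.
        raise NotImplementedError("provided by the template's servicer base class")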


Now, let's add the inference methods. call_inference_stream is the client side: it lets you call inference on other nodes. rpc_inference_stream is the server side: the RPC method registered to the DHT, which other nodes gather and invoke to run an inference request on your node.

Add:
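
A sketch of the two methods under the same assumptions: call_inference_stream tokenizes a prompt and streams tokens back from a remote peer's rpc_inference_stream handler via the stub, while rpc_inference_stream feeds incoming tensors through the local AsyncInferenceServer. Passing raw tensors here is a simplification; the template's real RPC layer serializes them.

    async def call_inference_stream(self, peer_id, prompt: str) -> AsyncIterator[torch.Tensor]:
        """Client side: run inference on another node and yield its tokens as they arrive."""
        input_ids = self.model.tokenizer(prompt, return_tensors="pt").input_ids
        stub = self.get_stub(peer_id)
        async for token in stub.rpc_inference_stream(input_ids):
            yield token

    async def rpc_inference_stream(self, tensors: torch.Tensor) -> AsyncIterator[torch.Tensor]:
        """Server side: the RPC method other peers gather from the DHT and call on this node."""
        stream = await self.server.submit(tensors, priority=10)
        async for token in stream:
            yield token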


Let's put it all together: the final inference_protocol.py combines the class skeleton and the methods above into a single InferenceProtocol class.

Starting Your Protocol

In the /server directory at subnet/app/server, open server.py and find the run() function in the Server class.

Replace MockProtocol with your newly developed protocol.

Because we are testing, remove or comment out the ConsensusThread class that is under the MockProtocol class. This way, you can run this locally without requiring the blockchain.
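
A rough sketch of the edit; the attribute names (self.protocol, self.dht, self.consensus) and constructor arguments are assumptions standing in for whatever the template's run() actually uses for MockProtocol.

        # Before (in Server.run()):
        # self.protocol = MockProtocol(...)

        # After: serve a real model instead of the mock
        self.protocol = InferenceProtocol(
            dht=self.dht,
            model_name_or_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        )

        # Local testing only: comment out consensus so the node runs without the blockchain
        # self.consensus = ConsensusThread(...)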

Setting up your environment

If you're running a native Linux environment, you can skip this.

If you're using a non-native environment such as WSL, you will need to route your IP and ports into the WSL environment.

Steps:

  1. Create the .wslconfig file in your user's home directory in Windows with the following contents:

  2. In WSL, find out the inet IP address of your WSL container (172.X.X.X):
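
For example, either of these prints the address:

ip addr show eth0 | grep "inet "
hostname -I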

  3. In Windows (PowerShell), allow traffic to be routed into the WSL container (replace 172.X.X.X with the inet IP address from step 2):
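
For example, assuming you expose ports 31330 and 31331 (the ports the two nodes use later in this guide):

netsh interface portproxy add v4tov4 listenaddress=0.0.0.0 listenport=31330 connectaddress=172.X.X.X connectport=31330
netsh interface portproxy add v4tov4 listenaddress=0.0.0.0 listenport=31331 connectaddress=172.X.X.X connectport=31331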

  4. Set up your firewall (e.g., Windows Defender) to allow traffic from the outside world to port 31330/tcp.

    1. You can also add an inbound rule for the port in the Windows Defender Firewall with Advanced Security settings.

      1. Type in the Windows search bar "Windows Defender Firewall with Advanced Security"

      2. Go to Inbound Rules

      3. Create a new rule allowing traffic on TCP ports 31330 and 31331

  5. If you have a router, set it up to allow connections from the outside world (port 31330/tcp) to your computer (port 31330/tcp).

    1. This is usually not required

In Windows PowerShell, you can run netsh interface portproxy show all. The WSL IP and port of choice should be shown if you did this correctly. This is the IP address you will use to start the node.

Note: We are not using the Windows IP address; we are using the WSL IP address. The WSL IP normally changes on each restart, but we locked the WSL IP from changing in step 1.


Start The Subnet

We're going to use a small but capable 1.1B-parameter model, TinyLlama/TinyLlama-1.1B-Chat-v1.0.

For both nodes, replace the 127.0.0.1 placeholder IP address with your IP address. If you're using WSL, use the IP address we got when running netsh interface portproxy show all.

The subnet template comes with two test RSA private keys in the repository root: server2.id and server3.id.

Start your first node:

We start this node on port 31330. It will serve as the bootnode for the second node to connect to.
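
The exact entry point and flag names come from the subnet template's README; the invocation below is only a hypothetical example following common libp2p-style conventions (an identity file plus listen and announce multiaddresses):

python -m subnet.cli.run_server \
  --identity_path server2.id \
  --host_maddrs /ip4/0.0.0.0/tcp/31330 \
  --announce_maddrs /ip4/127.0.0.1/tcp/31330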

Keep this node running and open a new CLI tab.

Start the second node (in a separate CLI):

We start this node using port 31331.
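
Again a hypothetical invocation; the important differences are the second identity file, the second port, and pointing the bootstrap flag (--initial_peers here) at the multiaddress the first node printed when it started:

python -m subnet.cli.run_server \
  --identity_path server3.id \
  --host_maddrs /ip4/0.0.0.0/tcp/31331 \
  --announce_maddrs /ip4/127.0.0.1/tcp/31331 \
  --initial_peers <multiaddress printed by the first node>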

We now have a fully decentralized network of nodes hosting models where nodes can call inference on each other.
