Build an Inference Subnet
Welcome to the Hypertensor protocol quickstart guide! In this walkthrough, we’ll create a simple inference protocol from start to finish. Whether you’re a seasoned developer or just starting out, this guide has got you covered.
What You’ll Achieve
By the end of this quickstart, you'll have built a protocol that can:
Load a model from Hugging Face
Register RPC methods so other nodes can call inference on your node
Call inference on other nodes
Be deployed in a distributed intelligence application
Set Up Your Development Environment
Clone the subnet template:
git clone https://github.com/hypertensor-blockchain/subnet-template.git
Find the /protocols directory in subnet/app/protocols and create two new files, inference_protocol.py and inference_model.py.
In inference_model.py, copy and paste the following code:
import asyncio
import threading
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional, Union
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    TextIteratorStreamer,
)
from transformers import logging as transformers_logging
from subnet.utils.logging import get_logger
transformers_logging.set_verbosity_info()
logger = get_logger(__name__)
"""
Used internally for validation
"""
def set_seed(seed: int = 42):
import random, numpy as np # noqa: E401, I001
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.use_deterministic_algorithms(True, warn_only=True)
class InferenceModel:
    def __init__(
        self,
        model_name_or_path: str,
        device: Optional[Union[str, torch.device]] = None,
    ):
        # Use the requested device, or fall back to GPU when available, else CPU
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"self.device {self.device}")
        self.model_name_or_path = model_name_or_path
        logger.info(f"Loading {self.model_name_or_path} tokenizer...")
        self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
        self.model.eval()
    async def stream_infer(
        self,
        inputs: torch.Tensor,
        max_new_tokens: int = 5,
        do_sample: bool = True
    ) -> AsyncIterator[torch.Tensor]:
        """
        Stream one token at a time as they're generated by the model in real time.
        """
        print("stream_infer.")
        logger.info("stream_infer.")
        queue: asyncio.Queue[Optional[torch.Tensor]] = asyncio.Queue()
        loop = asyncio.get_event_loop()

        # Move inputs to device
        inputs = inputs.to(self.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.long)

        # Setup TextStreamer
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=False,
        )

        # Generation thread
        def generate():
            try:
                with torch.no_grad():
                    self.model.generate(
                        input_ids=inputs,
                        attention_mask=attention_mask,
                        max_new_tokens=max_new_tokens,
                        eos_token_id=self.tokenizer.eos_token_id,
                        pad_token_id=self.tokenizer.eos_token_id,
                        do_sample=do_sample,
                        use_cache=True,
                        streamer=streamer,
                    )
            except Exception as e:
                logger.exception(f"Exception during model.generate: {e}", exc_info=True)
                loop.call_soon_threadsafe(queue.put_nowait, None)

        # Streaming thread
        def stream_worker():
            try:
                for decoded_token in streamer:
                    token_ids = self.tokenizer.encode(decoded_token, add_special_tokens=False)
                    for tid in token_ids:
                        print("tid", tid)
                        loop.call_soon_threadsafe(queue.put_nowait, torch.tensor([tid]))
            except Exception as e:
                logger.exception(f"Exception in stream_worker: {e}", exc_info=True)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, None)

        # Start both threads
        threading.Thread(target=generate, daemon=True).start()
        threading.Thread(target=stream_worker, daemon=True).start()

        # Yield tokens as they arrive
        while True:
            token = await queue.get()
            if token is None:
                break
            logger.debug(f"Yielding token: {token}")
            yield token
@dataclass
class InferenceTask:
    tensors: torch.Tensor
    future: asyncio.Future

@dataclass(order=True)
class PrioritizedTask:
    priority: int
    count: int  # Tie-breaker to ensure FIFO order for same priority
    task: InferenceTask = field(compare=False)

class AsyncInferenceServer:
    def __init__(self, model: InferenceModel):
        self.model = model
        self.queue: asyncio.PriorityQueue[PrioritizedTask] = asyncio.PriorityQueue()
        self._counter = 0  # For FIFO ordering of equal-priority tasks

    async def start(self):
        asyncio.create_task(self._worker())

    async def submit(self, tensors: torch.Tensor, priority: int = 10) -> AsyncIterator[torch.Tensor]:
        """
        Submit an inference task to the queue with an optional priority level.

        This method packages a tensor input into an `InferenceTask` and places it
        into an internal priority queue for asynchronous processing. Tasks with
        lower priority values (e.g., 0) are processed before those with higher
        values (e.g., 10). This allows important tasks (e.g., internal validator
        jobs) to cut ahead of less urgent ones.

        Args:
            tensors (torch.Tensor): The input tensor to be processed by the model.
            priority (int, optional): A numeric priority value for the task. Lower
                values mean higher priority. Defaults to 10.

        Returns:
            AsyncIterator[torch.Tensor]: An asynchronous iterator that yields
            inference outputs from the model.

        Example:
            stream = await model.submit(my_input_tensor, priority=5)
            async for token in stream:
                print(token)
        """
        logger.debug(f"Task submitted to async inference server, priority={priority}")
        print(f"Task submitted to async inference server, priority={priority}")
        future = asyncio.get_event_loop().create_future()
        task = InferenceTask(tensors=tensors, future=future)
        self._counter += 1
        prioritized_task = PrioritizedTask(priority=priority, count=self._counter, task=task)
        await self.queue.put(prioritized_task)
        return await future

    async def _worker(self):
        logger.debug("Priority queue worker started...")
        while True:
            try:
                prioritized_task: PrioritizedTask = await self.queue.get()
                task = prioritized_task.task
                logger.debug(f"Processing task with priority {prioritized_task.priority}")
                try:
                    stream = self.model.stream_infer(task.tensors)
                    task.future.set_result(stream)
                except Exception as e:
                    logger.error(f"Inference failed: {e}", exc_info=True)
                    task.future.set_exception(e)
                finally:
                    self.queue.task_done()
            except Exception as e:
                logger.error(f"Worker error: {e}", exc_info=True)
This will load the model from Hugging Face and let other nodes call inference on your node, with the generated tokens returned as a stream.
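As a quick local sanity check (the import path below is an assumption based on where we created the files), you can drive the model and the async server directly:
import asyncio

from subnet.app.protocols.inference_model import AsyncInferenceServer, InferenceModel  # assumed path

async def main():
    model = InferenceModel("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    server = AsyncInferenceServer(model)
    await server.start()

    # Tokenize a prompt and submit it with the default priority
    inputs = model.tokenizer("Hello, world", return_tensors="pt").input_ids
    stream = await server.submit(inputs)

    # Tokens arrive one at a time as small tensors
    async for token in stream:
        print(model.tokenizer.decode(token), end="", flush=True)

asyncio.run(main())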
In inference_protocol.py, start by creating the InferenceProtocol class:
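The exact base class and constructor come from the template; as an illustrative sketch (the constructor arguments here are assumptions), the class might start like this:
import asyncio
from typing import Optional

from subnet.utils.logging import get_logger

from .inference_model import AsyncInferenceServer, InferenceModel

logger = get_logger(__name__)

class InferenceProtocol:
    """Serves streaming inference to peers and lets this node call inference on other peers."""

    def __init__(self, dht, model_name: Optional[str] = None, start: bool = False):
        self.dht = dht  # the node's DHT/P2P handle provided by the template (assumed name)
        self.model_name = model_name
        self.model: Optional[InferenceModel] = None
        self.server: Optional[AsyncInferenceServer] = None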
Let's add a way to start the protocol and load the model (a sketch follows below).
This will start the protocol in the child process and register the RPC methods we will add later.
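A sketch of that startup path inside InferenceProtocol (the handler-registration call is an assumption about the template's P2P layer):
    def run(self):
        """Runs in the protocol's child process: load the model, then serve requests."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        async def _startup():
            # Load the model and start the priority-queue inference server
            self.model = InferenceModel(self.model_name)
            self.server = AsyncInferenceServer(self.model)
            await self.server.start()
            # Register the rpc_* handlers so peers can discover and call them (hypothetical helper)
            await self.add_p2p_handlers(self.dht)

        loop.create_task(_startup())
        loop.run_forever()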
We now need a way to gather the RPC methods so others can call inference on your node, and so you can call inference on other nodes (a sketch follows below).
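One common pattern, shown here only as an assumption rather than the template's exact API, is to collect every rpc_*-prefixed coroutine so it can be registered on the DHT and looked up by peers:
    @classmethod
    def _rpc_handlers(cls) -> dict:
        """Map RPC method names (rpc_*) to their callables for registration and lookup."""
        return {name: getattr(cls, name) for name in dir(cls) if name.startswith("rpc_")}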
Now, let's add the inference methods. The call_inference_stream method is what you use to request inference from another node, and rpc_inference_stream is the RPC method registered to the DHT that other nodes gather and call to run inference on you. Together they let you call inference on others and others call inference on you.
Add the following (an illustrative sketch is shown below):
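A sketch of the two methods inside InferenceProtocol (the stub plumbing and the request/response types are assumptions; the template provides the real serialization):
    async def call_inference_stream(self, peer_id, prompt: str):
        """Client side: ask the node at `peer_id` to run inference and stream tokens back."""
        stub = self.get_stub(self.dht, peer_id)  # hypothetical stub factory
        async for token in stub.rpc_inference_stream(prompt):
            yield token

    async def rpc_inference_stream(self, request):
        """Server side: registered on the DHT; runs inference locally and streams tokens out."""
        inputs = self.model.tokenizer(request, return_tensors="pt").input_ids
        stream = await self.server.submit(inputs, priority=10)
        async for token in stream:
            yield token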
Let's put it all together: the final inference_protocol.py combines the class, startup path, RPC gathering, and inference methods above.
Starting Your Protocol
In the /server directory (subnet/app/server), navigate to server.py and find the run() function in the Server class.
Replace MockProtocol with your newly developed protocol.
Because we are testing, remove or comment out the ConsensusThread usage below the MockProtocol. This way, you can run the node locally without requiring the blockchain.
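For example, inside run() (attribute and argument names other than InferenceProtocol and ConsensusThread are assumptions):
# subnet/app/server/server.py, inside Server.run() -- illustrative only
protocol = InferenceProtocol(
    dht=self.dht,  # assumed attribute name
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    start=True,
)

# Commented out for local testing so the node runs without the blockchain:
# consensus = ConsensusThread(...)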
Setting up your environment
If you're running a native Linux environment, you can skip this.
If you're not using a native Linux environment, such as WSL, you will need to route your IP and ports to your WSL environment.
Steps:
Create the .wslconfig file in your user's home directory in Windows with the following contents:
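The exact contents depend on your setup; a minimal example (an assumption, not necessarily the original file) is:
[wsl2]
localhostForwarding=true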
In WSL, find out the inet IP address of your WSL container (172.X.X.X):
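For example, from inside WSL:
ip addr show eth0 | grep "inet "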
In Windows (PowerShell), allow traffic to be routed into the WSL container (replace 172.X.X.X with the inet IP address from step 2):
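For example, forwarding both node ports used later in this guide:
netsh interface portproxy add v4tov4 listenaddress=0.0.0.0 listenport=31330 connectaddress=172.X.X.X connectport=31330
netsh interface portproxy add v4tov4 listenaddress=0.0.0.0 listenport=31331 connectaddress=172.X.X.X connectport=31331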
Set up your firewall (e.g., Windows Defender) to allow traffic from the outside world to port 31330/tcp.
You can also add an inbound rule for the port in the Windows Defender Firewall with Advanced Security settings.
Type in the Windows search bar "Windows Defender Firewall with Advanced Security"
Go to Inbound Rules
Create a new rule for ports (TCP) 31330 and 31331 to allow traffic
If you have a router, set it up to allow connections from the outside world (port 31330/tcp) to your computer (port 31330/tcp).
This is usually not required
In Windows PowerShell, you can run netsh interface portproxy show all. The WSL IP and port of choice should be shown if you did this correctly. This is the IP address you will use to start the node.
Start The Subnet
We're going to use a small but powerful 1B parameter model, TinyLlama/TinyLlama-1.1B-Chat-v1.0.
For both nodes, replace the 127.0.0.1 placeholder IP address with your IP address. If you're using WSL, use the IP address we got when running netsh interface portproxy show all.
The subnet template comes with two test RSA private keys in the root: server2.id and server3.id.
Start your first node:
We start this node using port 31330. It will be used as the bootnode for the second node to connect to.
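The exact command comes from the subnet template; an illustrative placeholder invocation (the entry point and flags below are assumptions) looks like:
python -m subnet.cli.run_server \
  --identity_path server2.id \
  --port 31330 \
  --announce_ip 127.0.0.1 \
  --model TinyLlama/TinyLlama-1.1B-Chat-v1.0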
Keep this node running and open a new CLI tab.
Start the second node (in a separate CLI)
We start this node using port 31331.
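Again as a placeholder (flags assumed), pointing at the first node as the bootnode:
python -m subnet.cli.run_server \
  --identity_path server3.id \
  --port 31331 \
  --announce_ip 127.0.0.1 \
  --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
  --bootnodes <first node's multiaddr>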
We now have a fully decentralized network of nodes hosting models where nodes can call inference on each other.