Build an Inference Subnet
Welcome to the Hypertensor protocol quickstart guide! In this walkthrough, we’ll create a simple inference protocol from start to finish. Whether you’re a seasoned developer or just starting out, this guide has got you covered.
What You’ll Achieve
By the end of this quickstart, you will have:
Loaded a model from Hugging Face
Registered RPC methods so other nodes can call inference on your node
Called inference on other nodes
Deployed the protocol in a distributed intelligence application
Set Up Your Development Environment
1. Clone the mesh template:
git clone https://github.com/hypertensor-blockchain/mesh.git
2. Find the /protocols directory in mesh/subnet/protocols and create two new files: inference_protocol.py and inference_model.py.
3. In inference_model.py, copy and paste the following code:
import asyncio
import threading
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional, Union
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizerBase,
TextIteratorStreamer,
)
from transformers import logging as transformers_logging
from hivemind.utils.logging import get_logger
transformers_logging.set_verbosity_info()
logger = get_logger(__name__)
"""
Used internally for validation
"""
def set_seed(seed: int = 42):
import random, numpy as np # noqa: E401, I001
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.use_deterministic_algorithms(True, warn_only=True)
class InferenceModel:
def __init__(
self,
model_name_or_path: str,
device: Optional[Union[str, torch.device]] = None,
):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
print(f"self.device {self.device}")
logger.info(f"self.device {self.device}")
self.model_name_or_path = model_name_or_path
logger.info(f"Loading {self.model_name_or_path} tokenizer...")
self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_name_or_path)
logger.info(f"Loading {self.model_name_or_path} model...")
        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        logger.info(f"Moving model to {self.device}")
        self.model.to(self.device)
logger.info("Moving to Evaluation Mode")
self.model.eval()
logger.info("Model loaded successfully.")
async def stream_infer(
self,
inputs: torch.Tensor,
max_new_tokens: int = 5,
do_sample: bool = True
) -> AsyncIterator[torch.Tensor]:
"""
Stream one token at a time as they're generated by the model in real time.
"""
print("stream_infer.")
logger.info("stream_infer.")
queue: asyncio.Queue[Optional[torch.Tensor]] = asyncio.Queue()
loop = asyncio.get_event_loop()
# Move inputs to device
inputs = inputs.to(self.device)
attention_mask = torch.ones_like(inputs, dtype=torch.long)
# Setup TextStreamer
streamer = TextIteratorStreamer(
self.tokenizer,
skip_prompt=True,
skip_special_tokens=False,
)
# Generation thread
def generate():
try:
with torch.no_grad():
self.model.generate(
input_ids=inputs,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.eos_token_id,
do_sample=do_sample,
use_cache=True,
streamer=streamer,
)
except Exception as e:
logger.exception(f"Exception during model.generate: {e}", exc_info=True)
loop.call_soon_threadsafe(queue.put_nowait, None)
# Streaming thread
def stream_worker():
try:
for decoded_token in streamer:
token_ids = self.tokenizer.encode(decoded_token, add_special_tokens=False)
for tid in token_ids:
print("tid", tid)
loop.call_soon_threadsafe(queue.put_nowait, torch.tensor([tid]))
except Exception as e:
logger.exception(f"Exception in stream_worker: {e}", exc_info=True)
finally:
loop.call_soon_threadsafe(queue.put_nowait, None)
# Start both threads
threading.Thread(target=generate, daemon=True).start()
threading.Thread(target=stream_worker, daemon=True).start()
# Yield tokens as they arrive
while True:
token = await queue.get()
if token is None:
break
logger.debug(f"Yielding token: {token}")
yield token
@dataclass
class InferenceTask:
tensors: torch.Tensor
future: asyncio.Future
@dataclass(order=True)
class PrioritizedTask:
priority: int
count: int # Tie-breaker to ensure FIFO order for same priority
task: InferenceTask = field(compare=False)
class AsyncInferenceServer:
def __init__(self, model: InferenceModel):
self.model = model
self.queue: asyncio.PriorityQueue[PrioritizedTask] = asyncio.PriorityQueue()
self._counter = 0 # For FIFO ordering of equal-priority tasks
async def start(self):
asyncio.create_task(self._worker())
async def submit(self, tensors: torch.Tensor, priority: int = 10) -> AsyncIterator[torch.Tensor]:
"""
Submit an inference task to the queue with an optional priority level.
This method packages a tensor input into an `InferenceTask` and places it
into an internal priority queue for asynchronous processing. Tasks with
lower priority values (e.g., 0) are processed before those with higher
values (e.g., 10). This allows important tasks (e.g., internal validator
jobs) to cut ahead of less urgent ones.
Args:
tensors (torch.Tensor): The input tensor to be processed by the model.
priority (int, optional): A numeric priority value for the task. Lower
values mean higher priority. Defaults to 10.
Returns:
AsyncIterator[torch.Tensor]: An asynchronous iterator that yields
inference outputs from the model.
Example:
stream = await model.submit(my_input_tensor, priority=5)
async for token in stream:
print(token)
"""
logger.debug(f"Task submitted to async inference server, priority={priority}")
print(f"Task submitted to async inference server, priority={priority}")
future = asyncio.get_event_loop().create_future()
task = InferenceTask(tensors=tensors, future=future)
self._counter += 1
prioritized_task = PrioritizedTask(priority=priority, count=self._counter, task=task)
await self.queue.put(prioritized_task)
return await future
async def _worker(self):
logger.debug("Priority queue worker started...")
while True:
try:
prioritized_task: PrioritizedTask = await self.queue.get()
task = prioritized_task.task
logger.debug(f"Processing task with priority {prioritized_task.priority}")
try:
stream = self.model.stream_infer(task.tensors)
task.future.set_result(stream)
except Exception as e:
logger.error(f"Inference failed: {e}", exc_info=True)
task.future.set_exception(e)
finally:
self.queue.task_done()
except Exception as e:
logger.error(f"Worker error: {e}", exc_info=True)
This will load the model from Hugging Face and let other nodes call inference on your node, with the output returned to them as a stream of tokens.
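Before moving on, you can optionally sanity-check inference_model.py on its own. The following is a minimal sketch, not part of the template: the script, prompt, and model choice are illustrative, and it assumes the mesh package is importable from where you run it.
import asyncio

from mesh.subnet.protocols.inference_model import AsyncInferenceServer, InferenceModel


async def main():
    # Load the model and start the priority-queue worker.
    model = InferenceModel("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    server = AsyncInferenceServer(model)
    await server.start()

    # Tokenize an example prompt and submit it; submit() resolves to an async token stream.
    inputs = model.tokenizer("The capital of France is", return_tensors="pt").input_ids
    stream = await server.submit(inputs, priority=0)

    # Decode and print tokens as they arrive (stream_infer defaults to max_new_tokens=5).
    async for token_tensor in stream:
        print(model.tokenizer.decode(token_tensor.tolist()), end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())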
4. In inference_protocol.py, start by creating the InferenceProtocol class:
from __future__ import annotations
import asyncio
import io
import multiprocessing as mp
from typing import AsyncIterator, Optional
import torch
import mesh
from mesh import DHT, get_dht_time
from mesh.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from mesh.p2p import P2P, P2PContext, PeerID, ServicerBase
from mesh.proto import dht_pb2, inference_protocol_pb2, runtime_pb2
from mesh.subnet.protocols.inference_model import AsyncInferenceServer, InferenceModel
from mesh.subnet.utils.consensus import get_consensus_key
from mesh.subnet.utils.key import extract_rsa_peer_id, extract_rsa_peer_id_from_ssh
from mesh.utils import get_logger
from mesh.utils.asyncio import switch_to_uvloop
from mesh.utils.auth import AuthorizerBase, AuthRole, AuthRPCWrapperStreamer
from mesh.utils.mpfuture import MPFuture
from mesh.utils.serializer import MSGPackSerializer
logger = get_logger(__name__)
class InferenceProtocol(mp.context.ForkProcess, ServicerBase):
_async_model: AsyncInferenceServer
def __init__(
self,
dht: DHT,
subnet_id: int,
model_name: Optional[str] = None,
balanced: bool = True,
shutdown_timeout: float = 3,
authorizer: Optional[AuthorizerBase] = None,
start: bool = False,
):
super().__init__()
self.dht = dht
self.subnet_id = subnet_id
self.peer_id = dht.peer_id
self.node_id = dht.node_id
self.node_info = dht_pb2.NodeInfo(node_id=self.node_id.to_bytes())
self.balanced, self.shutdown_timeout = balanced, shutdown_timeout
self._p2p = None
self.authorizer = authorizer
self.ready = MPFuture()
self.rpc_semaphore = asyncio.Semaphore(float("inf"))
self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
self.model_name = model_name
self.daemon = True
if start:
self.run_in_background(await_ready=True)
Let's add a way to start the protocol and load the model. Add:
def run(self):
torch.set_num_threads(1)
loop = switch_to_uvloop()
stop = asyncio.Event()
loop.add_reader(self._inner_pipe.fileno(), stop.set)
async def _run():
try:
self._p2p = await self.dht.replicate_p2p()
"""Add rpc_* methods from this class to the P2P servicer"""
if self.authorizer is not None:
logger.info("Adding P2P handlers with authorizer")
await self.add_p2p_handlers(
self._p2p,
AuthRPCWrapperStreamer(self, AuthRole.SERVICER, self.authorizer),
)
else:
await self.add_p2p_handlers(self._p2p, balanced=self.balanced)
if self.model_name is not None:
model = InferenceModel(self.model_name)
self._async_model = AsyncInferenceServer(model)
asyncio.create_task(self._async_model._worker())
self.ready.set_result(None)
except Exception as e:
logger.debug(e, exc_info=True)
self.ready.set_exception(e)
try:
await stop.wait()
finally:
await self.remove_p2p_handlers(self._p2p)
try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug("Caught KeyboardInterrupt, shutting down")
    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
        """
        Starts InferenceProtocol in a background process. If :await_ready:, this method will wait until
        it is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready:
            self.ready.result(timeout=timeout)
This starts the protocol in a child process and registers the RPC methods we will add next.
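For orientation, constructing the protocol from your own code looks roughly like the sketch below. This is a hedged example rather than part of the template: it assumes you already have a running mesh DHT instance, and the subnet_id and model name are placeholders. In this quickstart, server.py does this wiring for you.
# A minimal sketch, assuming `dht` is an already-running mesh DHT instance.
protocol = InferenceProtocol(
    dht=dht,
    subnet_id=1,
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    authorizer=None,   # or an AuthorizerBase implementation for authenticated RPC
    start=True,        # forks the child process, registers the rpc_* handlers, loads the model
)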
We now need a way to get a stub for another peer's RPC methods so you can call inference on other nodes (other nodes use the same mechanism to call yours). Add:
def get_stub(self, p2p: P2P, peer: PeerID) -> AuthRPCWrapperStreamer:
"""
Get a stub that sends requests to a given peer.
It's important here to wrap the stub with an authentication wrapper, see AuthRPCWrapper
"""
stub = super().get_stub(p2p, peer)
return AuthRPCWrapperStreamer(stub, AuthRole.CLIENT, self.authorizer, service_public_key=None)
Now, let's add the inference methods. call_inference_stream is what you use to request inference from another node. rpc_inference_stream is the RPC handler registered on your node's P2P servicer (and discoverable through the DHT); other nodes invoke it, through their own call_inference_stream, when they want your node to run inference for them.
Add:
async def call_inference_stream(
self, peer: PeerID, prompt: str, tensor: torch.Tensor
) -> AsyncIterator[torch.Tensor]:
"""
        Call another peer to perform streaming inference on `tensor`.
        The output is returned as a stream of tokens.
"""
input_stream = inference_protocol_pb2.InferenceRequestAuth(
input=prompt,
max_new_tokens=5,
tensor=serialize_torch_tensor(tensor),
)
try:
async with self.rpc_semaphore:
p2p = await self.dht.replicate_p2p()
response_stream = await self.get_stub(p2p, peer).rpc_inference_stream(input_stream)
async for response in response_stream:
for tensor_bytes in response.tensors:
tensor = deserialize_torch_tensor(tensor_bytes)
yield tensor
except Exception as e:
logger.error(f"InferenceProtocol failed to stream from {peer}: {e}", exc_info=True)
return
async def rpc_inference_stream(
self, requests: inference_protocol_pb2.InferenceRequestAuth, context: P2PContext
) -> AsyncIterator[inference_protocol_pb2.InferenceResponseAuth]:
"""
A peer wants us to perform an inference stream
"""
tensor = deserialize_torch_tensor(requests.tensor)
async for token_tensor in await self._async_model.submit(tensor):
yield inference_protocol_pb2.InferenceResponseAuth(
peer=self.node_info,
dht_time=get_dht_time(),
output=str(token_tensor.item()),
tensors=[serialize_torch_tensor(token_tensor)]
)
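For reference, a client would drive call_inference_stream roughly as in the hedged sketch below. The names protocol, peer_id, and tokenizer are illustrative assumptions (a ready InferenceProtocol, the target node's PeerID, and a tokenizer matching the remote model), not variables defined by the template.
# A minimal sketch of consuming the stream; `protocol`, `peer_id`, and `tokenizer`
# are assumed to already exist (see the end of this guide for a fuller end-to-end example).
async def ask(protocol, peer_id, tokenizer, prompt: str) -> str:
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated = []
    # Tokens arrive one at a time as the remote peer streams them back.
    async for token_tensor in protocol.call_inference_stream(peer_id, prompt, input_ids):
        generated.extend(token_tensor.flatten().tolist())
    return tokenizer.decode(generated)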
Let's put it all together. Here's the final result for inference_protocol.py:
from __future__ import annotations
import asyncio
import io
import multiprocessing as mp
from typing import AsyncIterator, Optional
import torch
import mesh
from mesh import DHT, get_dht_time
from mesh.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from mesh.p2p import P2P, P2PContext, PeerID, ServicerBase
from mesh.proto import dht_pb2, inference_protocol_pb2, runtime_pb2
from mesh.subnet.protocols.inference_model import AsyncInferenceServer, InferenceModel
from mesh.subnet.utils.consensus import get_consensus_key
from mesh.subnet.utils.key import extract_rsa_peer_id, extract_rsa_peer_id_from_ssh
from mesh.utils import get_logger
from mesh.utils.asyncio import switch_to_uvloop
from mesh.utils.auth import AuthorizerBase, AuthRole, AuthRPCWrapperStreamer
from mesh.utils.mpfuture import MPFuture
from mesh.utils.serializer import MSGPackSerializer
logger = get_logger(__name__)
class InferenceProtocol(mp.context.ForkProcess, ServicerBase):
_async_model: AsyncInferenceServer
def __init__(
self,
dht: DHT,
subnet_id: int,
model_name: Optional[str] = None,
balanced: bool = True,
shutdown_timeout: float = 3,
authorizer: Optional[AuthorizerBase] = None,
start: bool = False,
):
super().__init__()
self.dht = dht
self.subnet_id = subnet_id
self.peer_id = dht.peer_id
self.node_id = dht.node_id
self.node_info = dht_pb2.NodeInfo(node_id=self.node_id.to_bytes())
self.balanced, self.shutdown_timeout = balanced, shutdown_timeout
self._p2p = None
self.authorizer = authorizer
self.ready = MPFuture()
self.rpc_semaphore = asyncio.Semaphore(float("inf"))
self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
self.model_name = model_name
self.daemon = True
if start:
self.run_in_background(await_ready=True)
def run(self):
torch.set_num_threads(1)
loop = switch_to_uvloop()
stop = asyncio.Event()
loop.add_reader(self._inner_pipe.fileno(), stop.set)
async def _run():
try:
self._p2p = await self.dht.replicate_p2p()
"""Add rpc_* methods from this class to the P2P servicer"""
if self.authorizer is not None:
logger.info("Adding P2P handlers with authorizer")
await self.add_p2p_handlers(
self._p2p,
AuthRPCWrapperStreamer(self, AuthRole.SERVICER, self.authorizer),
)
else:
await self.add_p2p_handlers(self._p2p, balanced=self.balanced)
if self.model_name is not None:
model = InferenceModel(self.model_name)
self._async_model = AsyncInferenceServer(model)
asyncio.create_task(self._async_model._worker())
self.ready.set_result(None)
except Exception as e:
logger.debug(e, exc_info=True)
self.ready.set_exception(e)
try:
await stop.wait()
finally:
await self.remove_p2p_handlers(self._p2p)
try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug("Caught KeyboardInterrupt, shutting down")
    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
        """
        Starts InferenceProtocol in a background process. If :await_ready:, this method will wait until
        it is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready:
            self.ready.result(timeout=timeout)
def shutdown(self):
if self.is_alive():
self.join(self.shutdown_timeout)
if self.is_alive():
logger.warning(
"InferenceProtocol did not shut down within the grace period; terminating it the hard way"
)
self.terminate()
else:
logger.warning("InferenceProtocol shutdown had no effect, the process is already dead")
def get_stub(self, p2p: P2P, peer: PeerID) -> AuthRPCWrapperStreamer:
"""
Get a stub that sends requests to a given peer.
It's important here to wrap the stub with an authentication wrapper, see AuthRPCWrapper
"""
stub = super().get_stub(p2p, peer)
return AuthRPCWrapperStreamer(stub, AuthRole.CLIENT, self.authorizer, service_public_key=None)
async def call_inference_stream(
self, peer: PeerID, prompt: str, tensor: torch.Tensor
) -> AsyncIterator[torch.Tensor]:
"""
        Call another peer to perform streaming inference on `tensor`.
        The output is returned as a stream of tokens.
"""
input_stream = inference_protocol_pb2.InferenceRequestAuth(
input=prompt,
max_new_tokens=5,
tensor=serialize_torch_tensor(tensor),
)
try:
async with self.rpc_semaphore:
p2p = await self.dht.replicate_p2p()
response_stream = await self.get_stub(p2p, peer).rpc_inference_stream(input_stream)
async for response in response_stream:
for tensor_bytes in response.tensors:
tensor = deserialize_torch_tensor(tensor_bytes)
yield tensor
except Exception as e:
logger.error(f"InferenceProtocol failed to stream from {peer}: {e}", exc_info=True)
return
async def rpc_inference_stream(
self, requests: inference_protocol_pb2.InferenceRequestAuth, context: P2PContext
) -> AsyncIterator[inference_protocol_pb2.InferenceResponseAuth]:
"""
A peer wants us to perform an inference stream
"""
tensor = deserialize_torch_tensor(requests.tensor)
async for token_tensor in await self._async_model.submit(tensor):
yield inference_protocol_pb2.InferenceResponseAuth(
peer=self.node_info,
dht_time=get_dht_time(),
output=str(token_tensor.item()),
tensors=[serialize_torch_tensor(token_tensor)]
)
Starting Your Protocol
In the /server directory (mesh/subnet/server), open server.py and find the run() function in the Server class. Replace MockProtocol with your newly developed InferenceProtocol.
Because we are testing locally, also remove or comment out the ConsensusThread that is started below the MockProtocol. This way, you can run everything locally without requiring the blockchain.
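The exact wiring depends on the current server.py, but the change is conceptually small. The following is a hypothetical sketch only: attribute names such as self.dht, self.subnet_id, self.model_name_or_path, and self.authorizer are illustrative and should be adapted to whatever the Server class in the template actually exposes.
# Inside Server.run(), replacing the MockProtocol block (names are illustrative):
self.protocol = InferenceProtocol(
    dht=self.dht,
    subnet_id=self.subnet_id,
    model_name=self.model_name_or_path,
    authorizer=self.authorizer,
    start=True,
)

# While testing locally, skip starting the on-chain consensus loop:
# self.consensus = ConsensusThread(...)  # commented out so no blockchain is required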
Setting Up Your Environment
If you're running a native Linux environment, you can skip this section.
If you're using WSL instead of a native Linux environment, you will need to route your IP and ports into the WSL environment.
Steps:
Create the .wslconfig file in your user's home directory in Windows with the following contents:
[wsl2]
localhostforwarding=true
In WSL, find out the inet IP address of your WSL container (172.X.X.X):
sudo apt install net-tools
ifconfig
In Windows (PowerShell), allow traffic to be routed into the WSL container (replace 172.X.X.X with the inet IP address from step 2):
netsh interface portproxy add v4tov4 listenport=31330 listenaddress=0.0.0.0 connectport=31330 connectaddress=172.X.X.X
Set up your firewall (e.g., Windows Defender) to allow traffic from the outside world to port 31330/tcp.
You can also add an inbound rule for the port in the Windows Defender Firewall with Advanced Security settings.
Type in the Windows search bar "Windows Defender Firewall with Advanced Security"
Go to Inbound Rules
Create a new rule allowing inbound traffic on TCP ports 31330 and 31331
If you have a router, set it up to allow connections from the outside world (port 31330/tcp) to your computer (port 31330/tcp).
This is usually not required.
In Windows PowerShell, you can run netsh interface portproxy show all
. The WSL IP and port of choice should be shown if you did this correctly. This is the IP address you will use to start the node.
Start The Subnet
We're going to use a small but powerful 1.1B-parameter model, TinyLlama/TinyLlama-1.1B-Chat-v1.0.
For both nodes, replace the 127.0.0.1 placeholder IP address with your IP address. If you're using WSL, use the IP address shown by netsh interface portproxy show all.
The mesh template comes with two test RSA private keys in the repository root: server2.id and server3.id.
Start your first node:
We start this node using port 31330. This will be used as the bootstrap node for the second node to connect to.
mesh-server-mock TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host_maddrs /ip4/0.0.0.0/tcp/31330 /ip4/0.0.0.0/udp/31330/quic --announce_maddrs /ip4/127.0.0.1/tcp/31330 /ip4/127.0.0.1/udp/31330/quic --new_swarm --identity_path server2.id --subnet_id 1 --subnet_node_id 1
Keep this node running and open a new CLI tab.
Start the second node (in a separate CLI):
We start this node using port 31331.
mesh-server-mock TinyLlama/TinyLlama-1.1B-Chat-v1.0 --public_ip 127.0.0.1 --port 31331 --identity_path server3.id
We now have a fully decentralized network of nodes hosting the model, where any node can call inference on any other node.
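As a final sanity check, you could query the running swarm from a small client script. The sketch below is heavily hedged: it assumes mesh's DHT constructor mirrors hivemind's (initial_peers, start, client_mode), that you fill in the bootstrap multiaddress and the first node's peer ID from its startup logs, and that calling call_inference_stream from the parent process behaves as described above. Treat it as a starting point, not a guaranteed recipe.
import asyncio

from transformers import AutoTokenizer

from mesh import DHT
from mesh.p2p import PeerID
from mesh.subnet.protocols.inference_protocol import InferenceProtocol

# Placeholders: copy these from the first node's startup logs.
BOOTSTRAP_MADDR = "/ip4/127.0.0.1/tcp/31330/p2p/<first-node-peer-id>"
SERVER_PEER_ID = PeerID.from_base58("<first-node-peer-id>")


async def main():
    # Join the swarm as a lightweight client; no model_name, so no model is loaded locally.
    dht = DHT(initial_peers=[BOOTSTRAP_MADDR], start=True, client_mode=True)
    protocol = InferenceProtocol(dht=dht, subnet_id=1, start=True)

    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    prompt = "The capital of France is"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Stream tokens back from the remote node and decode them as they arrive.
    async for token in protocol.call_inference_stream(SERVER_PEER_ID, prompt, input_ids):
        print(tokenizer.decode(token.flatten().tolist()), end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())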