Build an Inference Subnet
What You’ll Achieve
Set Up Your Development Environment
git clone https://github.com/hypertensor-blockchain/subnet-template.git

import asyncio
import threading
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional, Union
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    TextIteratorStreamer,
)
from transformers import logging as transformers_logging
from subnet.utils.logging import get_logger
transformers_logging.set_verbosity_info()
logger = get_logger(__name__)
"""
Used internally for validation
"""
def set_seed(seed: int = 42):
import random, numpy as np # noqa: E401, I001
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.use_deterministic_algorithms(True, warn_only=True)
class InferenceModel:
    def __init__(
        self,
        model_name_or_path: str,
        device: Optional[Union[str, torch.device]] = None,
    ):
        # Resolve the target device up front so it can be used when loading the model.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        logger.info(f"self.device {self.device}")

        self.model_name_or_path = model_name_or_path
        logger.info(f"Loading {self.model_name_or_path} tokenizer...")
        self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
        self.model.eval()

    async def stream_infer(
        self,
        inputs: torch.Tensor,
        max_new_tokens: int = 5,
        do_sample: bool = True
    ) -> AsyncIterator[torch.Tensor]:
        """
        Stream tokens one at a time as they are generated by the model in real time.
        """
        logger.info("stream_infer.")
        queue: asyncio.Queue[Optional[torch.Tensor]] = asyncio.Queue()
        loop = asyncio.get_running_loop()

        # Move inputs to device
        inputs = inputs.to(self.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.long)

        # Set up the TextIteratorStreamer that generate() feeds decoded text into
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=False,
        )

        # Generation thread: runs model.generate and pushes text into the streamer
        def generate():
            try:
                with torch.no_grad():
                    self.model.generate(
                        input_ids=inputs,
                        attention_mask=attention_mask,
                        max_new_tokens=max_new_tokens,
                        eos_token_id=self.tokenizer.eos_token_id,
                        pad_token_id=self.tokenizer.eos_token_id,
                        do_sample=do_sample,
                        use_cache=True,
                        streamer=streamer,
                    )
            except Exception as e:
                logger.exception(f"Exception during model.generate: {e}")
                # Push the sentinel so the consumer below never blocks forever
                loop.call_soon_threadsafe(queue.put_nowait, None)

        # Streaming thread: re-encodes each decoded chunk back into token ids
        def stream_worker():
            try:
                for decoded_token in streamer:
                    token_ids = self.tokenizer.encode(decoded_token, add_special_tokens=False)
                    for tid in token_ids:
                        loop.call_soon_threadsafe(queue.put_nowait, torch.tensor([tid]))
            except Exception as e:
                logger.exception(f"Exception in stream_worker: {e}")
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, None)

        # Start both threads
        threading.Thread(target=generate, daemon=True).start()
        threading.Thread(target=stream_worker, daemon=True).start()

        # Yield tokens as they arrive
        while True:
            token = await queue.get()
            if token is None:
                break
            logger.debug(f"Yielding token: {token}")
            yield token

@dataclass
class InferenceTask:
    tensors: torch.Tensor
    future: asyncio.Future


@dataclass(order=True)
class PrioritizedTask:
    priority: int
    count: int  # Tie-breaker to ensure FIFO order for same priority
    task: InferenceTask = field(compare=False)

class AsyncInferenceServer:
    def __init__(self, model: InferenceModel):
        self.model = model
        self.queue: asyncio.PriorityQueue[PrioritizedTask] = asyncio.PriorityQueue()
        self._counter = 0  # For FIFO ordering of equal-priority tasks
        self._worker_task: Optional[asyncio.Task] = None

    async def start(self):
        # Keep a reference to the background worker task so it isn't garbage collected
        self._worker_task = asyncio.create_task(self._worker())

    async def submit(self, tensors: torch.Tensor, priority: int = 10) -> AsyncIterator[torch.Tensor]:
        """
        Submit an inference task to the queue with an optional priority level.

        This method packages a tensor input into an `InferenceTask` and places it
        into an internal priority queue for asynchronous processing. Tasks with
        lower priority values (e.g., 0) are processed before those with higher
        values (e.g., 10). This allows important tasks (e.g., internal validator
        jobs) to cut ahead of less urgent ones.

        Args:
            tensors (torch.Tensor): The input tensor to be processed by the model.
            priority (int, optional): A numeric priority value for the task. Lower
                values mean higher priority. Defaults to 10.

        Returns:
            AsyncIterator[torch.Tensor]: An asynchronous iterator that yields
            inference outputs from the model.

        Example:
            stream = await server.submit(my_input_tensor, priority=5)
            async for token in stream:
                print(token)
        """
        logger.debug(f"Task submitted to async inference server, priority={priority}")
        future = asyncio.get_running_loop().create_future()
        task = InferenceTask(tensors=tensors, future=future)
        self._counter += 1
        prioritized_task = PrioritizedTask(priority=priority, count=self._counter, task=task)
        await self.queue.put(prioritized_task)
        return await future

    async def _worker(self):
        logger.debug("Priority queue worker started...")
        while True:
            try:
                prioritized_task: PrioritizedTask = await self.queue.get()
                task = prioritized_task.task
                logger.debug(f"Processing task with priority {prioritized_task.priority}")
                try:
                    stream = self.model.stream_infer(task.tensors)
                    task.future.set_result(stream)
                except Exception as e:
                    logger.error(f"Inference failed: {e}", exc_info=True)
                    task.future.set_exception(e)
                finally:
                    self.queue.task_done()
            except Exception as e:
                logger.error(f"Worker error: {e}", exc_info=True)
Starting Your Protocol
Setting up your environment
Start The Subnet
Start your first node:
Start the second node (in a separate CLI)