How To Use

The following is a minimal working example of a Server, as shipped with the Mesh Template.

from __future__ import annotations

import threading
import time
from typing import Dict, List, Optional

import mesh
from mesh import DHT, get_dht_time
from mesh.dht.crypto import SignatureValidator
from mesh.dht.validation import HypertensorPredicateValidator, RecordValidatorBase
from mesh.subnet.consensus.consensus import Consensus
from mesh.subnet.utils.mock_commit_reveal import MockHypertensorCommitReveal
from mesh.substrate.chain_functions import Hypertensor
from mesh.substrate.mock.chain_functions import MockHypertensor
from mesh.utils.authorizers.auth import SignatureAuthorizer
from mesh.utils.authorizers.pos_auth_v2 import ProofOfStakeAuthorizer
from mesh.utils.data_structures import ServerClass, ServerInfo, ServerState
from mesh.utils.dht import declare_node_sig, get_node_infos_sig
from mesh.utils.key import get_private_key
from mesh.utils.logging import get_logger
from mesh.utils.ping import PingAggregator
from mesh.utils.proof_of_stake import ProofOfStake
from mesh.utils.random import sample_up_to
from mesh.utils.reachability import ReachabilityProtocol, check_direct_reachability
from mesh.utils.timed_storage import MAX_DHT_TIME_DISCREPANCY_SECONDS

# Module-level logger shared by every class in this file.
logger = get_logger(__name__)

# Value passed as `num_workers` when the Server constructs its DHT.
DEFAULT_NUM_WORKERS = 8

class Server:
    """
    Minimal example server for the Mesh Template.

    Construction sets up record validators and authorizers, optionally probes
    direct reachability, and starts the DHT. `run()` then starts the announcer
    and consensus helpers and blocks until `shutdown()` is called.
    """

    def __init__(
        self,
        *,
        initial_peers: List[str],
        public_name: Optional[str] = None,
        role: ServerClass,
        update_period: float = 60,
        expiration: Optional[float] = None,
        reachable_via_relay: Optional[bool] = None,
        use_relay: bool = True,
        use_auto_relay: bool = True,
        subnet_id: Optional[int] = None,
        subnet_node_id: Optional[int] = None,
        hypertensor: Optional[Hypertensor] = None,
        **kwargs,
    ):
        """
        Create a server

        Args:
            initial_peers: Peers handed to the reachability check and to the DHT.
            public_name: Stored in this node's `ServerInfo`.
            role: Stored in this node's `ServerInfo`.
            update_period: Heartbeat period forwarded to the announcer thread.
            expiration: Record lifetime forwarded to the announcer thread. When
                None, defaults to
                `max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)`.
            reachable_via_relay: When None, it is derived from
                `check_direct_reachability`. A truthy value starts the DHT in
                client mode and skips attaching the reachability protocol.
            use_relay: Forwarded to the DHT.
            use_auto_relay: Forwarded to the DHT.
            subnet_id: Forwarded to the predicate validator, PoS and consensus.
            subnet_node_id: Forwarded to consensus.
            hypertensor: Chain client. When None, a `MockHypertensor` backs the
                predicate validator and proof-of-stake authorization is skipped.
            **kwargs: Forwarded to `check_direct_reachability` and to the DHT.
                The keys `announce_maddrs` and `identity_path` are also read here.
        """
        self.reachability_protocol = None
        self.update_period = update_period
        if expiration is None:
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
        self.expiration = expiration

        self.initial_peers = initial_peers
        self.announce_maddrs = kwargs.get('announce_maddrs')  # None if 'announce_maddrs' is not present

        self.subnet_id = subnet_id
        self.subnet_node_id = subnet_node_id
        self.hypertensor = hypertensor

        identity_path = kwargs.get('identity_path', None)
        pk = get_private_key(identity_path)

        # Initialize record validators
        #
        # See https://docs.hypertensor.org/mesh-template/dht-records/record-validator

        # Initialize signature record validator. See https://docs.hypertensor.org/mesh-template/dht-records/record-validator/signature-validators
        self.signature_validator = SignatureValidator(pk)
        self.record_validators = [self.signature_validator]

        # Initialize predicate validator here. See https://docs.hypertensor.org/mesh-template/dht-records/record-validator/predicate-validators
        # Without a real chain client, fall back to the mock implementation.
        predicate_hypertensor = self.hypertensor if self.hypertensor is not None else MockHypertensor()
        consensus_predicate = HypertensorPredicateValidator.from_predicate_class(
            MockHypertensorCommitReveal, hypertensor=predicate_hypertensor, subnet_id=subnet_id
        )
        self.record_validators.append(consensus_predicate)

        # Initialize authorizers
        #
        # See https://docs.hypertensor.org/mesh-template/authorizers

        # Initialize signature authorizer. See https://docs.hypertensor.org/mesh-template/authorizers/signature-authorizer
        self.signature_authorizer = SignatureAuthorizer(pk)

        # Initialize PoS authorizer. See https://docs.hypertensor.org/mesh-template/authorizers/pos
        if self.hypertensor is not None:
            logger.info("Initializing PoS - proof-of-stake")
            pos = ProofOfStake(
                self.subnet_id,
                self.hypertensor,
                min_class=1,
            )
            self.pos_authorizer = ProofOfStakeAuthorizer(self.signature_authorizer, pk, pos)
        else:
            logger.info("Skipping PoS - proof-of-stake, using signature authorization only. If starting in production, make sure to use PoS")
            # For testing purposes, at minimum require signatures
            self.pos_authorizer = self.signature_authorizer

        # Test connecting to the DHT as a direct peer
        if reachable_via_relay is None:
            is_reachable = check_direct_reachability(initial_peers=initial_peers, authorizer=self.pos_authorizer, use_relay=False, **kwargs)
            reachable_via_relay = is_reachable is False  # if can't check reachability (returns None), run a full peer
            logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}")

        logger.info("About to run DHT")

        self.dht = DHT(
            initial_peers=initial_peers,
            start=True,
            num_workers=DEFAULT_NUM_WORKERS,
            use_relay=use_relay,
            use_auto_relay=use_auto_relay,
            client_mode=reachable_via_relay,
            record_validators=self.record_validators,
            **dict(kwargs, authorizer=self.pos_authorizer)
        )
        # Only directly reachable nodes attach the reachability protocol.
        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht, identity_path) if not reachable_via_relay else None

        visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]

        logger.info(f"Running a server on {visible_maddrs_str}")

        throughput_info = {"throughput": 1.0}
        self.server_info = ServerInfo(
            state=ServerState.JOINING,
            role=role,
            public_name=public_name,
            version="1.0.0",
            using_relay=reachable_via_relay,
            **throughput_info,
        )

        # Populated by run(); shutdown() skips whichever of these is still None.
        self.mock_protocol = None
        self.module_container = None
        self.consensus = None
        self.stop = threading.Event()

    def run(self):
        """
        Start the announcer and consensus helpers, then block until shutdown.

        Start protocols here, e.g.:

        self.protocol = MockProtocol(dht=self.dht)
        """
        self.module_container = ModuleAnnouncerThread(
            dht=self.dht,
            server_info=self.server_info,
            record_validator=self.signature_validator,
            update_period=self.update_period,
            expiration=self.expiration,
            start=True
        )

        self.consensus = ConsensusThread(
            dht=self.dht,
            server_info=self.server_info,
            subnet_id=self.subnet_id,
            subnet_node_id=self.subnet_node_id,
            record_validator=self.signature_validator,
            hypertensor=self.hypertensor,
            start=True
        )

        # Keep server running forever (until shutdown() sets the stop event)
        self.stop.wait()

    def shutdown(self, timeout: Optional[float] = 5):
        """
        Stop every started component, then shut down and join the DHT.

        Args:
            timeout: Currently unused; kept so existing callers keep working.
        """
        logger.info("Shutting down Server, wait to shutdown properly")
        self.stop.set()

        # Announce OFFLINE while the DHT is still running; otherwise the node
        # record stays in its last announced state until it expires.
        if self.module_container is not None:
            self.module_container.shutdown()

        if self.mock_protocol is not None:
            self.mock_protocol.shutdown()

        if self.reachability_protocol is not None:
            self.reachability_protocol.shutdown()

        if self.consensus is not None:
            self.consensus.shutdown()

        self.dht.shutdown()
        self.dht.join()

class ModuleAnnouncerThread(threading.Thread):
    """
    Owns the heartbeat thread and drives the node's announced state.

    The constructor starts a daemon `ModuleHeartbeatThread` with the state set
    to JOINING; running this thread announces ONLINE, and `shutdown()`
    announces OFFLINE.
    """

    def __init__(
        self,
        dht: DHT,
        server_info: ServerInfo,
        record_validator: RecordValidatorBase,
        update_period: float,
        expiration: Optional[float] = None,
        start: bool = True,
    ):
        """
        Args:
            dht: DHT instance handed to the heartbeat thread.
            server_info: Node info; its `state` is set to JOINING here.
            record_validator: Validator handed to the heartbeat thread.
            update_period: Heartbeat period handed to the heartbeat thread.
            expiration: Record lifetime handed to the heartbeat thread. When
                None, defaults to
                `max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)`.
            start: When True, start this thread at the end of construction.
        """
        super().__init__()
        self.dht = dht

        # The heartbeat adds `expiration` to the DHT time, so None must be
        # resolved to a number before it is forwarded.
        if expiration is None:
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)

        server_info.state = ServerState.JOINING
        self.dht_announcer = ModuleHeartbeatThread(
            dht,
            server_info,
            record_validator,
            update_period=update_period,
            expiration=expiration,
            daemon=True,
        )
        self.role = server_info.role
        self.dht_announcer.start()
        logger.info("Announced to the DHT that we are joining")

        if start:
            self.start()

    def run(self):
        """Switch the announced state to ONLINE."""
        logger.info("Announcing that node is online")
        self.dht_announcer.announce(ServerState.ONLINE)

    def shutdown(self):
        """
        Announce OFFLINE, then wait briefly for this thread to finish.
        """
        self.dht_announcer.announce(ServerState.OFFLINE)
        logger.info("Announced to the DHT that we are exiting")

        # Never join ourselves, and never join a thread that is not running.
        if self.is_alive() and threading.current_thread() is not self:
            self.join(timeout=5)
        logger.info("Module shut down successfully")

class ConsensusThread:
    """
    Holder that builds and tears down the `Consensus` component.

    Despite its name this class does not subclass `threading.Thread`; `run()`
    executes in the caller's thread.
    """

    def __init__(
        self,
        dht: DHT,
        server_info: ServerInfo,
        subnet_id: int,
        subnet_node_id: int,
        record_validator: RecordValidatorBase,
        hypertensor: Hypertensor,
        start: bool = True,
    ):
        """
        Store the dependencies and, when `start` is true, call `run()` at once.
        """
        super().__init__()

        # Components that shutdown() stops; both are unset until created.
        self.consensus = None
        self.validator = None

        self.dht = dht
        self.server_info = server_info
        self.signature_validator = record_validator
        self.hypertensor = hypertensor
        self.subnet_id = subnet_id
        self.subnet_node_id = subnet_node_id

        if not start:
            return
        self.run()

    def run(self) -> None:
        """
        Create the `Consensus` instance with `start=True`.

        Add any other logic the Consensus class requires to run here,
        such as different node role classes, etc.

        See template implementation
        """
        consensus_options = dict(
            dht=self.dht,
            subnet_id=self.subnet_id,
            subnet_node_id=self.subnet_node_id,
            record_validator=self.signature_validator,
            hypertensor=self.hypertensor,
            skip_activate_subnet=False,
            start=True,
        )
        self.consensus = Consensus(**consensus_options)

        logger.info("Starting consensus")

    def shutdown(self):
        """Shut down the consensus component, then the validator, if present."""
        for component in (self.consensus, self.validator):
            if component is not None:
                component.shutdown()

class ModuleHeartbeatThread(threading.Thread):
    """Periodically announces server is live before expiration of storage, visible to all DHT peers"""

    def __init__(
        self,
        dht: DHT,
        server_info: ServerInfo,
        record_validator: RecordValidatorBase,
        *,
        update_period: float,
        expiration: float,
        max_pinged: int = 5,
        **kwargs,
    ):
        """
        Args:
            dht: DHT used to declare this node and to look up other nodes.
            server_info: Node info that is declared on every heartbeat; its
                `state` and `next_pings` fields are updated by this thread.
            record_validator: Validator passed to `declare_node_sig`.
            update_period: Target number of seconds between heartbeats.
            expiration: Seconds added to the DHT time to form the record's
                expiration time.
            max_pinged: Upper bound on the number of peers pinged per heartbeat.
            **kwargs: Forwarded to `threading.Thread.__init__` (e.g. `daemon`).
        """
        super().__init__(**kwargs)
        self.dht = dht
        self.server_info = server_info
        self.record_validator = record_validator

        self.update_period = update_period
        self.expiration = expiration
        # Set by announce() to cut the sleep between heartbeats short.
        self.trigger = threading.Event()

        self.max_pinged = max_pinged
        self.ping_aggregator = PingAggregator(self.dht)

    def run(self) -> None:
        """
        Start heartbeat

        - Tell the network you're still here
        - Ping other nodes

        The loop ends after the first declaration made in the OFFLINE state.
        """
        while True:
            start_time = time.perf_counter()

            if self.server_info.state != ServerState.OFFLINE:
                self._ping_next_servers()
                self.server_info.next_pings = {
                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items()
                }
            else:
                self.server_info.next_pings = None  # No need to ping if we're disconnecting

            logger.info("Declaring node [Heartbeat]...")

            # Do not change the "node" key
            #
            # See https://docs.hypertensor.org/build-a-subnet/requirements#node-key-public-key-subkey
            declare_node_sig(
                dht=self.dht,
                key="node",
                server_info=self.server_info,
                expiration_time=get_dht_time() + self.expiration,
                record_validator=self.record_validator
            )

            if self.server_info.state == ServerState.OFFLINE:
                break

            # If you want to host multiple applications in one DHT or run a bootstrap node that acts as an entry
            # point to multiple subnets, you can do so in the DHTStorage mechanism.
            #
            # Without a clear understanding of how DHTs or DHTStorage work, we suggest isolating subnets and not
            # using this.
            #
            # if not self.dht_prefix.startswith("_"):
            #     self.dht.store(
            #         key="_team_name_here.subnets",
            #         subkey=self.dht_prefix,
            #         value=self.model_info.to_dict(),
            #         expiration_time=get_dht_time() + self.expiration,
            #     )

            # Sleep for whatever is left of the period; announce() can wake us early.
            delay = self.update_period - (time.perf_counter() - start_time)
            if delay < 0:
                logger.warning(
                    f"Declaring node to DHT takes more than --update_period, consider increasing it (currently {self.update_period})"
                )
            self.trigger.wait(max(delay, 0))
            self.trigger.clear()

    def announce(self, state: ServerState) -> None:
        """
        Set the state to declare and wake the heartbeat loop immediately.

        When `state` is OFFLINE and the thread is running, wait up to five
        seconds for it to finish.
        """
        self.server_info.state = state
        self.trigger.set()
        if state == ServerState.OFFLINE:
            if self.is_alive():
                self.join(timeout=5)

    def _ping_next_servers(self) -> None:
        """
        Ping up to `max_pinged` peers found under the "node" key.

        Results are accumulated in `self.ping_aggregator`; nothing is returned.
        """
        module_infos = get_node_infos_sig(
            self.dht,
            uid="node",
            latest=True
        )
        if not module_infos:
            return
        candidates = {info.peer_id for info in module_infos}
        # Exclude ourselves before sampling so the local peer cannot use up
        # one of the `max_pinged` slots.
        candidates.discard(self.dht.peer_id)
        pinged_servers = set(sample_up_to(candidates, self.max_pinged)) if candidates else set()
        self.ping_aggregator.ping(list(pinged_servers))

Last updated