How To Use
A minimal working example of the `Server` class might look like this:
class Server:
    """Minimal subnet server.

    Construction connects to the DHT and builds this node's ``ServerInfo``
    record. ``run()`` starts the protocol/announcer/consensus components and
    blocks until ``shutdown()`` is called.
    """

    def __init__(
        self,
        *,
        initial_peers: List[str],
        converted_model_name_or_path: str,
        public_name: Optional[str] = None,
        role: ServerClass,
        update_period: float = 60,
        expiration: Optional[float] = None,
        skip_reachability_check: bool = False,
        reachable_via_relay: Optional[bool] = None,
        use_relay: bool = True,
        use_auto_relay: bool = True,
        subnet_id: Optional[int] = None,
        subnet_node_id: Optional[int] = None,
        hypertensor: Optional[Hypertensor] = None,
        **kwargs,
    ):
        """
        Create a server and connect it to the DHT.

        :param initial_peers: peers used to bootstrap the DHT (and the reachability check)
        :param converted_model_name_or_path: stored on the instance; not otherwise used here
        :param public_name: forwarded into ``ServerInfo``
        :param role: this node's role, forwarded into ``ServerInfo``
        :param update_period: seconds between heartbeat announcements
        :param expiration: lifetime of DHT announcements; defaults to
            ``max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)``
        :param skip_reachability_check: accepted for compatibility; not used in this example
        :param reachable_via_relay: if ``None``, determined with ``check_direct_reachability``;
            when true the DHT runs in client mode and no ReachabilityProtocol is attached
        :param use_relay: forwarded to the DHT
        :param use_auto_relay: forwarded to the DHT
        :param subnet_id: forwarded to the protocol and consensus components in ``run()``
        :param subnet_node_id: forwarded to the consensus component in ``run()``
        :param hypertensor: forwarded to the protocol and consensus components in ``run()``
        :param kwargs: forwarded to ``check_direct_reachability`` and ``DHT``;
            ``announce_maddrs`` and ``identity_path`` are also read from here
        """
        self.update_period = update_period
        if expiration is None:
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
        self.expiration = expiration
        self.converted_model_name_or_path = converted_model_name_or_path
        self.initial_peers = initial_peers
        self.announce_maddrs = kwargs.get('announce_maddrs')  # None if 'announce_maddrs' not present
        self.subnet_id = subnet_id
        self.subnet_node_id = subnet_node_id
        self.hypertensor = hypertensor

        # Connect to DHT
        if reachable_via_relay is None:
            is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
            reachable_via_relay = is_reachable is False  # if can't check reachability (returns None), run a full peer
            logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}")

        # NOTE(review): identity_path may be None here; behavior then depends on get_rsa_private_key.
        identity_path = kwargs.get('identity_path', None)
        pk = get_rsa_private_key(identity_path)
        self.rsa_signature_validator = RSASignatureValidator(pk)
        self.record_validators = [self.rsa_signature_validator]

        self.dht = DHT(
            initial_peers=initial_peers,
            start=True,
            num_workers=DEFAULT_NUM_WORKERS,
            use_relay=use_relay,
            use_auto_relay=use_auto_relay,
            client_mode=reachable_via_relay,
            record_validators=self.record_validators,
            **kwargs,
        )
        # Only directly reachable (non-client) peers serve reachability checks for others.
        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None

        visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
        logger.info(f"Running a server on {visible_maddrs_str}")

        throughput_info = {"throughput": 1.0}  # placeholder throughput for this example
        self.server_info = ServerInfo(
            state=ServerState.JOINING,
            role=role,
            public_name=public_name,
            version="1.0.0",
            using_relay=reachable_via_relay,
            **throughput_info,
        )

        # Components created in run(); initialised here so shutdown() is safe before run().
        self.mock_protocol = None
        self.inference_protocol = None
        self.module_container = None
        self.consensus = None
        self.stop = threading.Event()

    def run(self):
        """Start the protocol, announcer and consensus components, then block until ``shutdown()``."""
        self.mock_protocol = MockProtocol(
            dht=self.dht,
            subnet_id=self.subnet_id,
            hypertensor=self.hypertensor,
            authorizer=None,
            start=True
        )
        self.module_container = ModuleAnnouncerThread(
            dht=self.dht,
            server_info=self.server_info,
            update_period=self.update_period,
            expiration=self.expiration,
            start=True
        )
        self.consensus = ConsensusThread(
            dht=self.dht,
            server_info=self.server_info,
            subnet_id=self.subnet_id,
            subnet_node_id=self.subnet_node_id,
            record_validator=self.rsa_signature_validator,
            hypertensor=self.hypertensor,
            start=True
        )

        # Keep server running until shutdown() sets the stop event.
        self.stop.wait()

    def shutdown(self, timeout: Optional[float] = 5):
        """Unblock ``run()``, stop started components and shut down the DHT.

        :param timeout: maximum seconds to wait for the DHT to finish after
            ``shutdown()``; ``None`` waits indefinitely
        """
        logger.info("Shutting down Server, wait to shutdown properly")
        self.stop.set()

        if self.mock_protocol is not None:
            self.mock_protocol.shutdown()

        if self.reachability_protocol is not None:
            self.reachability_protocol.shutdown()

        if self.consensus is not None:
            self.consensus.shutdown()

        self.dht.shutdown()
        self.dht.join(timeout)
class ModuleAnnouncerThread(threading.Thread):
    """Owns a ``ModuleHeartbeatThread`` and drives this node's JOINING/ONLINE/OFFLINE announcements.

    NOTE(review): with ``start=True`` the constructor calls ``run()`` directly on
    the calling thread, so this Thread object itself is never ``start()``ed.
    """

    def __init__(
        self,
        dht: DHT,
        server_info: ServerInfo,
        update_period: float,
        expiration: Optional[float] = None,
        start: bool = True,
    ):
        """
        :param dht: DHT the heartbeat declares this node on
        :param server_info: this node's record; its state is set to JOINING here
        :param update_period: seconds between heartbeat announcements
        :param expiration: lifetime of each announcement, forwarded to the heartbeat
        :param start: if true, immediately announce ONLINE via ``run()``
        """
        super().__init__()
        self.dht = dht

        server_info.state = ServerState.JOINING
        self.dht_announcer = ModuleHeartbeatThread(
            dht,
            server_info,
            update_period=update_period,
            expiration=expiration,
            daemon=True,
        )
        self.role = server_info.role
        # The heartbeat's first iteration declares the JOINING state set above.
        self.dht_announcer.start()
        logger.info("Announced to the DHT that we are joining")

        if start:
            self.run()

    def run(self):
        """Switch the announced state to ONLINE."""
        logger.info("Announcing node is online")
        self.dht_announcer.announce(ServerState.ONLINE)

    def shutdown(self):
        """
        Gracefully terminate the container: announce OFFLINE (which waits for the
        heartbeat thread to exit) and join this thread if it was ever started.
        """
        self.dht_announcer.announce(ServerState.OFFLINE)

        logger.info("Announced to the DHT that we are exiting")

        # run() may have been called synchronously, in which case this thread was
        # never started and join() would raise RuntimeError.
        if self.ident is not None:
            self.join()

        logger.info("Module shut down successfully")
class ConsensusThread(threading.Thread):
    """Builds and owns this node's ``Consensus`` instance (and optional validator).

    NOTE(review): with ``start=True`` the constructor calls ``run()`` directly on
    the calling thread, so this Thread object itself is never ``start()``ed.
    """

    def __init__(
        self,
        dht: DHT,
        server_info: ServerInfo,
        subnet_id: int,
        subnet_node_id: int,
        record_validator: RecordValidatorBase,
        hypertensor: Hypertensor,
        start: bool = True,
    ):
        """
        :param dht: DHT passed to ``Consensus``
        :param server_info: this node's record; its ``role`` is passed to ``Consensus``
        :param subnet_id: passed to ``Consensus``
        :param subnet_node_id: passed to ``Consensus``
        :param record_validator: signature validator passed to ``Consensus``
        :param hypertensor: passed to ``Consensus``
        :param start: if true, build ``Consensus`` immediately via ``run()``
        """
        super().__init__()
        self.dht = dht
        self.server_info = server_info
        self.subnet_id = subnet_id
        self.subnet_node_id = subnet_node_id
        self.rsa_signature_validator = record_validator
        self.hypertensor = hypertensor

        # Defaults so run() and shutdown() work when no validator is configured;
        # a template implementation may assign self.validator inside run().
        self.validator = None
        self.consensus = None

        if start:
            self.run()

    def run(self) -> None:
        """
        Build the ``Consensus`` instance. Add any other logic the Consensus class
        requires here, such as different node role classes. For example::

            self.validator = Validator(
                role=self.server_info.role,
                dht=self.dht,
                record_validator=self.rsa_signature_validator,
                hypertensor=self.hypertensor,
            )

        See template implementation.
        """
        self.consensus = Consensus(
            dht=self.dht,
            subnet_id=self.subnet_id,
            subnet_node_id=self.subnet_node_id,
            role=self.server_info.role,
            record_validator=self.rsa_signature_validator,
            hypertensor=self.hypertensor,
            validator=self.validator,
            start=True,
        )

    def shutdown(self):
        """Shut down the consensus and validator (if created) and join this thread if it was started."""
        if self.consensus is not None:
            self.consensus.shutdown()

        if self.validator is not None:
            self.validator.shutdown()

        # run() may have been called synchronously, in which case this thread was
        # never started and join() would raise RuntimeError.
        if self.ident is not None:
            self.join()
class ModuleHeartbeatThread(threading.Thread):
    """Periodically announces server is live before expiration of storage, visible to all DHT peers"""

    def __init__(
        self,
        dht: DHT,
        server_info: ServerInfo,
        *,
        update_period: float,
        expiration: float,
        max_pinged: int = 5,
        **kwargs,
    ):
        """
        :param dht: DHT to declare this node on and to look up peers in
        :param server_info: record declared on every iteration (state and next_pings are read/written)
        :param update_period: target seconds between iterations
        :param expiration: seconds each declaration remains valid after it is made
        :param max_pinged: maximum number of peers pinged per iteration
        :param kwargs: forwarded to ``threading.Thread`` (e.g. ``daemon``)
        """
        super().__init__(**kwargs)
        self.dht = dht
        self.server_info = server_info

        self.update_period = update_period
        self.expiration = expiration
        # Set by announce() to wake the loop early after a state change.
        self.trigger = threading.Event()

        self.max_pinged = max_pinged
        self.ping_aggregator = PingAggregator(self.dht)

    def run(self) -> None:
        """
        Heartbeat loop:
        - ping a sample of other nodes (unless going OFFLINE)
        - declare ``server_info`` to the DHT
        - exit after the OFFLINE state has been declared
        """
        while True:
            start_time = time.perf_counter()

            if self.server_info.state != ServerState.OFFLINE:
                self._ping_next_servers()
                self.server_info.next_pings = {
                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items()
                }
                logger.debug("next_pings: %s", self.server_info.next_pings)
            else:
                self.server_info.next_pings = None  # No need to ping if we're disconnecting

            # NOTE(review): nodes are declared under key "validator" while
            # _ping_next_servers looks up uid "hoster" — confirm this is intended.
            declare_node(
                dht=self.dht,
                key="validator",
                server_info=self.server_info,
                expiration_time=get_dht_time() + self.expiration,
            )

            if self.server_info.state == ServerState.OFFLINE:
                break

            delay = self.update_period - (time.perf_counter() - start_time)
            if delay < 0:
                logger.warning(
                    f"Declaring node to DHT takes more than --update_period, consider increasing it (currently {self.update_period})"
                )
            self.trigger.wait(max(delay, 0))
            self.trigger.clear()

    def announce(self, state: ServerState) -> None:
        """Set the state to announce and wake the loop; for OFFLINE, wait for the loop to exit."""
        self.server_info.state = state
        self.trigger.set()
        if state == ServerState.OFFLINE:
            self.join()

    def _ping_next_servers(self) -> None:
        """Ping up to ``max_pinged`` other peers found under uid "hoster"; results land in ``ping_aggregator``."""
        module_infos = get_node_infos(
            self.dht,
            uid="hoster",
            latest=True
        )
        if not module_infos:
            return

        candidate_peers = {info.peer_id for info in module_infos}
        # Drop ourselves before sampling so we don't waste one of the max_pinged slots.
        candidate_peers.discard(self.dht.peer_id)
        pinged_servers = set(sample_up_to(candidate_peers, self.max_pinged))
        self.ping_aggregator.ping(list(pinged_servers))
Last updated