Record Validator
Record validators are a generic mechanism for validating DHT records, including:
Enforcing a data schema (e.g., checking content types)
Enforcing security requirements (e.g., allowing only the owner to update a record)
Enforcing custom logic on the DHTRecord via predicates/callables (e.g., checking that keys match an allowlist, checking expiration dates, etc.)
When starting the DHT, it can be initialized with an Iterable of RecordValidatorBase instances (see the sketch below).
It is recommended to use the provided validators, or create your own, to enforce security requirements on your DHT records.
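A minimal sketch of wiring validators into the DHT at startup is shown below. The DHT class and its record_validators argument follow the hivemind-style API; the import path and argument name are assumptions and may differ in the template. PredicateValidator and Ed25519SignatureValidator are documented later in this section.
# Minimal sketch (assumed hivemind-style API): the DHT takes an iterable of
# RecordValidatorBase instances at construction time.
from hivemind import DHT  # assumption: the template exposes a hivemind-style DHT

validators = [
    Ed25519SignatureValidator(),                              # only owners may update protected records
    PredicateValidator(lambda r: len(r.value) <= 64 * 1024),  # reject oversized values
]

dht = DHT(start=True, record_validators=validators)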
RecordValidatorBase
RecordValidatorBase is the abstract base class that all record validators must implement.
class RecordValidatorBase(ABC):
"""
Record validators are a generic mechanism for checking the DHT records including:
- Enforcing a data schema (e.g. checking content types)
- Enforcing security requirements (e.g. allowing only the owner to update the record)
"""
@abstractmethod
def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
"""
Should return whether the `record` is valid based on request type.
The valid records should have been extended with sign_value().
validate() is called when another DHT peer:
- Asks us to store the record
- Returns the record by our request
"""
pass
def sign_value(self, record: DHTRecord) -> bytes:
"""
Should return `record.value` extended with the record's signature.
Note: there's no need to overwrite this method if a validator doesn't use a signature.
sign_value() is called after the application asks the DHT to store the record.
"""
return record.value
def strip_value(self, record: DHTRecord) -> bytes:
"""
Should return `record.value` stripped of the record's signature.
strip_value() is only called if validate() was successful.
Note: there's no need to overwrite this method if a validator doesn't use a signature.
strip_value() is called before the DHT returns the record by the application's request.
"""
return record.value
@property
def priority(self) -> int:
"""
Defines the order of applying this validator with respect to other validators.
The validators are applied:
- In order of increasing priority for signing a record
- In order of decreasing priority for validating and stripping a record
"""
return 0
def merge_with(self, other: "RecordValidatorBase") -> bool:
"""
By default, all validators are applied sequentially (i.e. we require all validate() calls
to return True for a record to be validated successfully).
However, you may want to define another policy for combining your validator classes
(e.g. for schema validators, we want to require only one validate() call to return True
because each validator bears a part of the schema).
This can be achieved with overriding merge_with(). It should:
- Return True if it has successfully merged the `other` validator to `self`,
so that `self` became a validator that combines the old `self` and `other` using
the necessary policy. In this case, `other` should remain unchanged.
- Return False if the merging has not happened. In this case, both `self` and `other`
should remain unchanged. The DHT will try merging `other` to another validator or
add it as a separate validator (to be applied sequentially).
"""
return False
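As an illustration, a hypothetical validator that only accepts records below a size limit can be implemented by subclassing RecordValidatorBase; it uses no signatures, so sign_value() and strip_value() keep their defaults:
# Hypothetical example: reject values larger than a fixed size.
class MaxSizeValidator(RecordValidatorBase):
    def __init__(self, max_value_bytes: int = 64 * 1024):
        self.max_value_bytes = max_value_bytes

    def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
        return len(record.value) <= self.max_value_bytes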
By default, the template comes with the following Record Validators:
PredicateValidator
A general-purpose DHT validator that delegates all validation logic to a custom predicate.
This is a minimal validator that can enforce any condition on the entire DHTRecord. Useful for filtering keys, expiration time, value content, or any combination thereof.
class PredicateValidator(RecordValidatorBase):
"""
A general-purpose DHT validator that delegates all validation logic to a custom callable.
This is a minimal validator that can enforce any condition on the entire DHTRecord.
Useful for filtering keys, expiration time, value content, or any combination thereof.
This can be used to ensure keys match a specific format, or that nodes are acting within a certain period
of time in relation to the blockchain, e.g., enforcing a commit-reveal scheme where the commit is submitted in the
first half of the epoch and the reveal is done in the second half of the epoch.
Attributes:
record_predicate (Callable[[DHTRecord], bool]): A user-defined function that receives a record and returns True if valid.
"""
def __init__(
self,
record_predicate: Callable[[DHTRecord], bool] = lambda r: True,
):
self.record_predicate = record_predicate
def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
return self.record_predicate(record)
def sign_value(self, record: DHTRecord) -> bytes:
return record.value
def strip_value(self, record: DHTRecord) -> bytes:
return record.value
def merge_with(self, other: RecordValidatorBase) -> bool:
if not isinstance(other, PredicateValidator):
return False
# Ignore another PredicateValidator instance (it doesn't make sense to have several
# instances of this class) and report a successful merge
return True
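For example, a predicate can restrict which keys are accepted. The key names below are purely illustrative:
# Illustrative usage: only accept records stored under the (hypothetical) keys
# "mysubnet_heartbeat" and "mysubnet_scores", and require a non-empty value.
ALLOWED_KEY_IDS = {
    DHTID.generate(source=f"mysubnet_{name}").to_bytes() for name in ("heartbeat", "scores")
}

key_validator = PredicateValidator(
    record_predicate=lambda record: record.key in ALLOWED_KEY_IDS and len(record.value) > 0,
)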
Hypertensor Predicate Validator
Similar to the PredicateValidator, except that the callable also receives the request type and the current epoch data:
@dataclass
class EpochData:
block: int
epoch: int
block_per_epoch: int
seconds_per_epoch: int
percent_complete: float
blocks_elapsed: int
blocks_remaining: int
seconds_elapsed: int
seconds_remaining: int
This is useful for time-based conditions, such as commit-reveal schemes that must stay in sync with the Hypertensor blockchain clock.
class HypertensorPredicateValidator(RecordValidatorBase):
"""
A general-purpose DHT validator that delegates all validation logic to a custom callable
that also receives the request type and the current epoch data.
This is a minimal validator that can enforce any condition on the entire DHTRecord.
Useful for filtering keys, expiration time, value content, or any combination thereof.
This can be used to ensure keys match a specific format, or that nodes are acting within a certain period
of time in relation to the blockchain, e.g., enforcing a commit-reveal scheme where the commit is submitted in the
first half of the epoch and the reveal is done in the second half of the epoch.
Attributes:
record_predicate (Callable[[DHTRecord, DHTRequestType, EpochData], bool]): A user-defined function that receives the record, the request type, and the current epoch data, and returns True if valid.
"""
def __init__(
self,
hypertensor: Hypertensor,
record_predicate: Callable[[DHTRecord, DHTRequestType, EpochData], bool] = lambda record, type, epoch_data: True,
):
self.record_predicate = record_predicate
self.hypertensor = hypertensor
def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
return self.record_predicate(record, type, self._epoch_data())
def sign_value(self, record: DHTRecord) -> bytes:
return record.value
def strip_value(self, record: DHTRecord) -> bytes:
return record.value
def _epoch_data(self):
# Get epoch data from the blockchain and calculate the remaining
return self.hypertensor.get_epoch_progress()
def merge_with(self, other: RecordValidatorBase) -> bool:
if not isinstance(other, HypertensorPredicateValidator):
return False
# Ignore another HypertensorPredicateValidator instance (it doesn't make sense to have several
# instances of this class) and report a successful merge
return True
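A sketch of a commit-reveal style predicate follows. The "commit" and "reveal" key names are illustrative, the hypertensor client instance is assumed to exist already, and percent_complete is assumed to be a fraction in [0, 1]:
# Sketch: epoch-gated predicate (key names and threshold are illustrative assumptions).
COMMIT_KEY_ID = DHTID.generate(source="commit").to_bytes()
REVEAL_KEY_ID = DHTID.generate(source="reveal").to_bytes()

def commit_reveal_predicate(record: DHTRecord, type: DHTRequestType, epoch_data: EpochData) -> bool:
    if record.key == COMMIT_KEY_ID:
        return epoch_data.percent_complete < 0.5   # commits only in the first half of the epoch
    if record.key == REVEAL_KEY_ID:
        return epoch_data.percent_complete >= 0.5  # reveals only in the second half
    return True                                    # other keys are not time-gated

validator = HypertensorPredicateValidator(
    hypertensor=hypertensor,  # an existing Hypertensor client instance (assumed)
    record_predicate=commit_reveal_predicate,
)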
SchemaValidator
Restricts specified DHT keys to match a Pydantic schema. This allows for enforcing types, min/max values, requiring a subkey to contain a public key, etc.
class SchemaValidator(RecordValidatorBase):
"""
Restricts specified DHT keys to match a Pydantic schema.
This makes it possible to enforce types, min/max values, require a subkey to contain a public key, etc.
"""
def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None):
"""
:param schema: The Pydantic model (a subclass of pydantic.BaseModel).
You must always use strict types for the number fields
(e.g. ``StrictInt`` instead of ``int``,
``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
See the validate() docstring for details.
The model will be patched to adjust it for the schema validation.
:param allow_extra_keys: Whether to allow keys that are not defined in the schema.
If a SchemaValidator is merged with another SchemaValidator, this option applies to
keys that are not defined in each of the schemas.
:param prefix: (optional) Add ``prefix + '_'`` to the names of all schema fields.
"""
self._patch_schema(schema)
self._schemas = [schema]
self._key_id_to_field_name = {}
for field in schema.__fields__.values():
raw_key = f"{prefix}_{field.name}" if prefix is not None else field.name
self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name
self._allow_extra_keys = allow_extra_keys
@staticmethod
def _patch_schema(schema: pydantic.BaseModel):
# We set required=False because the validate() interface provides only one key at a time
for field in schema.__fields__.values():
field.required = False
schema.Config.extra = pydantic.Extra.forbid
def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
"""
Validates ``record`` in two steps:
1. Create a Pydantic model and ensure that no exceptions are thrown.
2. Ensure that Pydantic has not made any type conversions [1]_ while creating the model.
To do this, we check that the value of the model field is equal
(in terms of == operator) to the source value.
This works for the iterable default types like str, list, and dict
(they are equal only if the types match) but does not work for numbers
(they have a special case allowing ``3.0 == 3`` to be true). [2]_
Because of that, you must always use strict types [3]_ for the number fields
(e.g. to avoid ``3.0`` to be validated successfully for the ``field: int``).
.. [1] https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
.. [2] https://stackoverflow.com/a/52557261
.. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
"""
if record.key not in self._key_id_to_field_name:
if not self._allow_extra_keys:
logger.debug(
f"Record {record} has a key ID that is not defined in any of the "
f"schemas (therefore, the raw key is unknown)"
)
return self._allow_extra_keys
try:
record = self._deserialize_record(record)
except ValueError as e:
logger.debug(e)
return False
[field_name] = list(record.keys())
validation_errors = []
for schema in self._schemas:
try:
parsed_record = schema.parse_obj(record)
except pydantic.ValidationError as e:
if not self._is_failed_due_to_extra_field(e):
validation_errors.append(e)
continue
parsed_value = parsed_record.dict(by_alias=True)[field_name]
if parsed_value != record[field_name]:
validation_errors.append(
ValueError(
f"The record {record} needed type conversions to match "
f"the schema: {parsed_value}. Type conversions are not allowed"
)
)
else:
return True
logger.debug(f"Record {record} doesn't match any of the schemas: {validation_errors}")
return False
def _deserialize_record(self, record: DHTRecord) -> Dict[str, Any]:
field_name = self._key_id_to_field_name[record.key]
deserialized_value = DHTProtocol.serializer.loads(record.value)
if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
return {field_name: {deserialized_subkey: deserialized_value}}
else:
if isinstance(deserialized_value, dict):
raise ValueError(
f"Record {record} contains an improperly serialized dictionary (you must use "
f"a DictionaryDHTValue of serialized values instead of a `dict` subclass)"
)
return {field_name: deserialized_value}
@staticmethod
def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
inner_errors = exc.errors()
return (
len(inner_errors) == 1
and inner_errors[0]["type"] == "value_error.extra"
and len(inner_errors[0]["loc"]) == 1 # Require the extra field to be on the top level
)
def merge_with(self, other: RecordValidatorBase) -> bool:
if not isinstance(other, SchemaValidator):
return False
self._schemas.extend(other._schemas)
self._key_id_to_field_name.update(other._key_id_to_field_name)
self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
return True
def __setstate__(self, state):
self.__dict__.update(state)
# If unpickling happens in another process, the previous model modifications may be lost
for schema in self._schemas:
self._patch_schema(schema)
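For example, assuming Pydantic v1-style models (consistent with the parse_obj and __fields__ usage above), a schema restricting two hypothetical keys could look like this:
# Hypothetical schema: note the strict number types required by the validate() docstring.
import pydantic

class TrainingMetricsSchema(pydantic.BaseModel):
    step: pydantic.StrictInt                      # stored under the DHT key "step"
    loss: pydantic.confloat(strict=True, ge=0.0)  # stored under the DHT key "loss"

schema_validator = SchemaValidator(TrainingMetricsSchema, allow_extra_keys=False)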
SignatureValidator (RSA and Ed25519)
Introduces the notion of protected records whose key/subkey contains the substring "[owner:ssh-ed25519 ...]" with an Ed25519 public key of the owner.
If this validator is used, changes to such records must always be signed with the corresponding private key (so only the owner can change them).
class Ed25519SignatureValidator(RecordValidatorBase):
"""
Introduces a notion of *protected records* whose key/subkey contains substring
"[owner:ssh-ed25519 ...]" with an Ed25519 public key of the owner.
If this validator is used, changes to such records must always be signed with
the corresponding private key (so only the owner can change them).
"""
PUBLIC_KEY_FORMAT = b"[owner:_key_]"
SIGNATURE_FORMAT = b"[signature:_value_]"
PUBLIC_KEY_REGEX = re.escape(PUBLIC_KEY_FORMAT).replace(b"_key_", rb"(.+?)")
_PUBLIC_KEY_RE = re.compile(PUBLIC_KEY_REGEX)
_SIGNATURE_RE = re.compile(re.escape(SIGNATURE_FORMAT).replace(b"_value_", rb"(.+?)"))
_cached_private_key = None
def __init__(self, private_key: Optional[Ed25519PrivateKey] = None):
if private_key is None:
private_key = Ed25519PrivateKey.process_wide()
self._private_key = private_key
serialized_public_key = private_key.get_public_key().to_bytes()
self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b"_key_", serialized_public_key)
@property
def local_public_key(self) -> bytes:
return self._local_public_key
def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
public_keys = self._PUBLIC_KEY_RE.findall(record.key)
if record.subkey is not None:
public_keys += self._PUBLIC_KEY_RE.findall(record.subkey)
if not public_keys:
return True # The record is not protected with a public key
if len(set(public_keys)) > 1:
logger.debug(f"Key and subkey can't contain different public keys in {record}")
return False
public_key = Ed25519PublicKey.from_bytes(public_keys[0])
signatures = self._SIGNATURE_RE.findall(record.value)
if len(signatures) != 1:
logger.debug(f"Record should have exactly one signature in {record}")
return False
signature = signatures[0]
stripped_record = dataclasses.replace(record, value=self.strip_value(record))
if not public_key.verify(self._serialize_record(stripped_record), signature):
logger.debug(f"Signature is invalid in {record}")
return False
return True
def sign_value(self, record: DHTRecord) -> bytes:
if self._local_public_key not in record.key and self._local_public_key not in record.subkey:
return record.value
signature = self._private_key.sign(self._serialize_record(record))
return record.value + self.SIGNATURE_FORMAT.replace(b"_value_", signature)
def strip_value(self, record: DHTRecord) -> bytes:
return self._SIGNATURE_RE.sub(b"", record.value)
def _serialize_record(self, record: DHTRecord) -> bytes:
return MSGPackSerializer.dumps(dataclasses.astuple(record))
@property
def priority(self) -> int:
# On validation, this validator must be executed before validators
# that deserialize the record
return 10
def merge_with(self, other: RecordValidatorBase) -> bool:
if not isinstance(other, Ed25519SignatureValidator):
return False
# Ignore another Ed25519SignatureValidator instance (it doesn't make sense to have several
# instances of this class) and report successful merge
return True
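In practice, the owner's public key is embedded into the key or subkey when the record is stored, so that only this peer can later update it. A sketch, assuming a hivemind-style dht.store() signature and time helper; the key name is illustrative:
# Sketch: storing a protected record (hivemind-style store() signature assumed).
from hivemind.utils import get_dht_time  # assumption: hivemind-style DHT time helper

signature_validator = Ed25519SignatureValidator()

dht.store(
    key="my-protected-key",                       # illustrative key name
    subkey=signature_validator.local_public_key,  # embeds "[owner:ssh-ed25519 ...]"
    value=b"some payload",
    expiration_time=get_dht_time() + 60,
)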