Record Validator

Record validators are a generic mechanism for checking the DHT records, including:

  • Enforcing a data schema (e.g., checking content types)

  • Enforcing security requirements (e.g., allowing only the owner to update the record)

  • Enforcing custom logic on the DHTRecord using predicates/callables (e.g., checking that keys match allowable keys, checking expiration dates, etc.)

When starting the DHT, it can be initialized with an Iterable of RecordValidatorBase instances.

It is recommended to use the built-in validators, or create your own, to enforce security requirements for your DHT records.

RecordValidatorBase

The RecordValidatorBase is the abstract base class for all record validators to implement.

class RecordValidatorBase(ABC):
    """
    Record validators are a generic mechanism for checking the DHT records including:
      - Enforcing a data schema (e.g. checking content types)
      - Enforcing security requirements (e.g. allowing only the owner to update the record)
    """

    @abstractmethod
    def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
        """
        Should return whether the `record` is valid based on request type.
        The valid records should have been extended with sign_value().

        validate() is called when another DHT peer:
          - Asks us to store the record
          - Returns the record by our request
        """

        pass

    def sign_value(self, record: DHTRecord) -> bytes:
        """
        Should return `record.value` extended with the record's signature.

        Note: there's no need to overwrite this method if a validator doesn't use a signature.

        sign_value() is called after the application asks the DHT to store the record.
        """

        return record.value

    def strip_value(self, record: DHTRecord) -> bytes:
        """
        Should return `record.value` stripped of the record's signature.
        strip_value() is only called if validate() was successful.

        Note: there's no need to overwrite this method if a validator doesn't use a signature.

        strip_value() is called before the DHT returns the record by the application's request.
        """

        return record.value

    @property
    def priority(self) -> int:
        """
        Defines the order of applying this validator with respect to other validators.

        The validators are applied:
          - In order of increasing priority for signing a record
          - In order of decreasing priority for validating and stripping a record
        """

        return 0

    def merge_with(self, other: "RecordValidatorBase") -> bool:
        """
        By default, all validators are applied sequentially (i.e. we require all validate() calls
        to return True for a record to be validated successfully).

        However, you may want to define another policy for combining your validator classes
        (e.g. for schema validators, we want to require only one validate() call to return True
        because each validator bears a part of the schema).

        This can be achieved with overriding merge_with(). It should:

          - Return True if it has successfully merged the `other` validator to `self`,
            so that `self` became a validator that combines the old `self` and `other` using
            the necessary policy. In this case, `other` should remain unchanged.

          - Return False if the merging has not happened. In this case, both `self` and `other`
            should remain unchanged. The DHT will try merging `other` to another validator or
            add it as a separate validator (to be applied sequentially).
        """

        return False
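As an illustration, here is a minimal sketch of a custom validator built on this interface. The `DHTRecord` dataclass and the base class below are simplified stand-ins (the real types live in the template), and `MaxTTLValidator` is a hypothetical policy, not part of the template:

```python
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass


# Stand-in for the template's DHTRecord, for illustration only.
@dataclass
class DHTRecord:
    key: bytes
    subkey: bytes
    value: bytes
    expiration_time: float


# Simplified stand-in for the template's RecordValidatorBase.
class RecordValidatorBase(ABC):
    @abstractmethod
    def validate(self, record: DHTRecord, type=None) -> bool:
        ...

    def sign_value(self, record: DHTRecord) -> bytes:
        return record.value

    def strip_value(self, record: DHTRecord) -> bytes:
        return record.value

    @property
    def priority(self) -> int:
        return 0

    def merge_with(self, other: "RecordValidatorBase") -> bool:
        return False


class MaxTTLValidator(RecordValidatorBase):
    """Hypothetical validator: rejects records that expire too far in the future."""

    def __init__(self, max_ttl_seconds: float):
        self.max_ttl_seconds = max_ttl_seconds

    def validate(self, record: DHTRecord, type=None) -> bool:
        # Accept the record only if it expires within the allowed window.
        return record.expiration_time <= time.time() + self.max_ttl_seconds


validator = MaxTTLValidator(max_ttl_seconds=3600)
short_lived = DHTRecord(b"k", b"s", b"v", expiration_time=time.time() + 60)
long_lived = DHTRecord(b"k", b"s", b"v", expiration_time=time.time() + 7200)
```

Since this validator uses no signature, `sign_value()` and `strip_value()` keep their default pass-through behavior.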

By default, the template comes with the following Record Validators:

PredicateValidator

A general-purpose DHT validator that delegates all validation logic to a custom predicate.

This is a minimal validator that can enforce any condition on the entire DHTRecord. Useful for filtering keys, expiration time, value content, or any combination thereof.

class PredicateValidator(RecordValidatorBase):
    """
    A general-purpose DHT validator that delegates all validation logic to a custom callable.

    This is a minimal validator that can enforce any condition on the entire DHTRecord.
    Useful for filtering keys, expiration time, value content, or any combination thereof.

    This can be used to ensure keys match a specific format, or that nodes act within a certain period
    of time relative to the blockchain, e.g., enforcing a commit-reveal scheme where the commit is submitted in
    the first half of the epoch and the reveal in the second half.

    Attributes:
        record_predicate (Callable[[DHTRecord], bool]): A user-defined function that receives a record and returns True if valid.
    """

    def __init__(
        self,
        record_predicate: Callable[[DHTRecord], bool] = lambda r: True,
    ):
        self.record_predicate = record_predicate

    def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
        return self.record_predicate(record)

    def sign_value(self, record: DHTRecord) -> bytes:
        return record.value

    def strip_value(self, record: DHTRecord) -> bytes:
        return record.value

    def merge_with(self, other: RecordValidatorBase) -> bool:
        if not isinstance(other, PredicateValidator):
            return False

        # Ignore another PredicateValidator instance (it doesn't make sense to have several
        # instances of this class) and report a successful merge
        return True
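For example, a predicate might cap the size of stored values. The policy below is hypothetical, and the record objects are simple stand-ins for `DHTRecord`; in the template you would wire the predicate up as `PredicateValidator(record_predicate=max_value_size)`:

```python
from types import SimpleNamespace


# Hypothetical policy: reject values larger than 1 KiB.
def max_value_size(record) -> bool:
    return len(record.value) <= 1024


# In the template, this would be passed to the validator:
# validator = PredicateValidator(record_predicate=max_value_size)

small_record = SimpleNamespace(value=b"x" * 100)   # stand-in for a DHTRecord
large_record = SimpleNamespace(value=b"x" * 4096)
```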

Hypertensor Predicate Validator

Similar to the PredicateValidator, but the callable also receives the request type and the current epoch data, the latter given as:

@dataclass
class EpochData:
    block: int
    epoch: int
    block_per_epoch: int
    seconds_per_epoch: int
    percent_complete: float
    blocks_elapsed: int
    blocks_remaining: int
    seconds_elapsed: int
    seconds_remaining: int

This is useful for having conditions based on time, such as for commit-reveal schemes that should be synced with the Hypertensor blockchain clock.
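The fields of EpochData are mutually consistent. The sketch below shows one assumed way they relate (it presumes epoch 0 starts at block 0 and a fixed block time); in the template, `Hypertensor.get_epoch_progress()` produces this data from the chain, and its exact derivation may differ:

```python
from dataclasses import dataclass


@dataclass
class EpochData:
    block: int
    epoch: int
    block_per_epoch: int
    seconds_per_epoch: int
    percent_complete: float
    blocks_elapsed: int
    blocks_remaining: int
    seconds_elapsed: int
    seconds_remaining: int


def epoch_data_from_block(block: int, block_per_epoch: int, seconds_per_block: int) -> EpochData:
    """Assumed reconstruction: epochs are contiguous runs of `block_per_epoch` blocks."""
    epoch, blocks_elapsed = divmod(block, block_per_epoch)
    blocks_remaining = block_per_epoch - blocks_elapsed
    return EpochData(
        block=block,
        epoch=epoch,
        block_per_epoch=block_per_epoch,
        seconds_per_epoch=block_per_epoch * seconds_per_block,
        percent_complete=blocks_elapsed / block_per_epoch,
        blocks_elapsed=blocks_elapsed,
        blocks_remaining=blocks_remaining,
        seconds_elapsed=blocks_elapsed * seconds_per_block,
        seconds_remaining=blocks_remaining * seconds_per_block,
    )
```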

class HypertensorPredicateValidator(RecordValidatorBase):
    """
    A general-purpose DHT validator that delegates all validation logic to a custom callable
    that also receives the current epoch data.

    This is a minimal validator that can enforce any condition on the entire DHTRecord.
    Useful for filtering keys, expiration time, value content, or any combination thereof.

    This can be used to ensure keys match a specific format, or that nodes act within a certain period
    of time relative to the blockchain, e.g., enforcing a commit-reveal scheme where the commit is submitted in
    the first half of the epoch and the reveal in the second half.

    Attributes:
        record_predicate (Callable[[DHTRecord, DHTRequestType, EpochData], bool]): A user-defined function
            that receives the record, the request type, and the current epoch data, and returns True if valid.
    """

    def __init__(
        self,
        hypertensor: Hypertensor,
        record_predicate: Callable[[DHTRecord, DHTRequestType, EpochData], bool] = lambda r, t, e: True
    ):
        self.record_predicate = record_predicate
        self.hypertensor = hypertensor

    def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
        return self.record_predicate(record, type, self._epoch_data())

    def sign_value(self, record: DHTRecord) -> bytes:
        return record.value

    def strip_value(self, record: DHTRecord) -> bytes:
        return record.value

    def _epoch_data(self):
        # Get epoch data from the blockchain and calculate the remaining
        return self.hypertensor.get_epoch_progress()

    def merge_with(self, other: RecordValidatorBase) -> bool:
        if not isinstance(other, HypertensorPredicateValidator):
            return False

        # Ignore another HypertensorPredicateValidator instance (it doesn't make sense to have several
        # instances of this class) and report a successful merge
        return True
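As a sketch of the commit-reveal idea, a predicate can gate writes on `percent_complete`. The key tags below are hypothetical, and the record/epoch objects are simple stand-ins; note that the real DHT hashes keys into DHTIDs, so a production predicate would compare against precomputed key IDs rather than raw prefixes:

```python
from types import SimpleNamespace

COMMIT_TAG = b"commit"   # hypothetical tags; real keys are hashed DHTIDs
REVEAL_TAG = b"reveal"


def commit_reveal_predicate(record, request_type, epoch_data) -> bool:
    """Allow commits only in the first half of the epoch, reveals only in the second."""
    if record.key.startswith(COMMIT_TAG):
        return epoch_data.percent_complete < 0.5
    if record.key.startswith(REVEAL_TAG):
        return epoch_data.percent_complete >= 0.5
    return True  # keys unrelated to commit-reveal are not restricted by this predicate


# In the template, this would be wired up as:
# validator = HypertensorPredicateValidator(hypertensor, record_predicate=commit_reveal_predicate)

first_half = SimpleNamespace(percent_complete=0.2)    # stand-ins for EpochData
second_half = SimpleNamespace(percent_complete=0.8)
commit_record = SimpleNamespace(key=b"commit:round-1")
reveal_record = SimpleNamespace(key=b"reveal:round-1")
```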

SchemaValidator

Restricts specified DHT keys to match a Pydantic schema. This allows for enforcing types, min/max values, requiring a subkey to contain a public key, etc.

class SchemaValidator(RecordValidatorBase):
    """
    Restricts specified DHT keys to match a Pydantic schema.
    This allows enforcing types, min/max values, requiring a subkey to contain a public key, etc.
    """

    def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None):
        """
        :param schema: The Pydantic model (a subclass of pydantic.BaseModel).

            You must always use strict types for the number fields
            (e.g. ``StrictInt`` instead of ``int``,
            ``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
            See the validate() docstring for details.

            The model will be patched to adjust it for the schema validation.

        :param allow_extra_keys: Whether to allow keys that are not defined in the schema.

            If a SchemaValidator is merged with another SchemaValidator, this option applies to
            keys that are not defined in each of the schemas.

        :param prefix: (optional) Add ``prefix + '_'`` to the names of all schema fields.
        """

        self._patch_schema(schema)
        self._schemas = [schema]

        self._key_id_to_field_name = {}
        for field in schema.__fields__.values():
            raw_key = f"{prefix}_{field.name}" if prefix is not None else field.name
            self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name
        self._allow_extra_keys = allow_extra_keys

    @staticmethod
    def _patch_schema(schema: pydantic.BaseModel):
        # We set required=False because the validate() interface provides only one key at a time
        for field in schema.__fields__.values():
            field.required = False

        schema.Config.extra = pydantic.Extra.forbid

    def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
        """
        Validates ``record`` in two steps:

        1. Create a Pydantic model and ensure that no exceptions are thrown.

        2. Ensure that Pydantic has not made any type conversions [1]_ while creating the model.
           To do this, we check that the value of the model field is equal
           (in terms of == operator) to the source value.

           This works for the iterable default types like str, list, and dict
           (they are equal only if the types match) but does not work for numbers
           (they have a special case allowing ``3.0 == 3`` to be true). [2]_

           Because of that, you must always use strict types [3]_ for the number fields
           (e.g. to avoid ``3.0`` to be validated successfully for the ``field: int``).

           .. [1] https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
           .. [2] https://stackoverflow.com/a/52557261
           .. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
        """

        if record.key not in self._key_id_to_field_name:
            if not self._allow_extra_keys:
                logger.debug(
                    f"Record {record} has a key ID that is not defined in any of the "
                    f"schemas (therefore, the raw key is unknown)"
                )
            return self._allow_extra_keys

        try:
            record = self._deserialize_record(record)
        except ValueError as e:
            logger.debug(e)
            return False
        [field_name] = list(record.keys())

        validation_errors = []
        for schema in self._schemas:
            try:
                parsed_record = schema.parse_obj(record)
            except pydantic.ValidationError as e:
                if not self._is_failed_due_to_extra_field(e):
                    validation_errors.append(e)
                continue

            parsed_value = parsed_record.dict(by_alias=True)[field_name]
            if parsed_value != record[field_name]:
                validation_errors.append(
                    ValueError(
                        f"The record {record} needed type conversions to match "
                        f"the schema: {parsed_value}. Type conversions are not allowed"
                    )
                )
            else:
                return True

        logger.debug(f"Record {record} doesn't match any of the schemas: {validation_errors}")
        return False

    def _deserialize_record(self, record: DHTRecord) -> Dict[str, Any]:
        field_name = self._key_id_to_field_name[record.key]
        deserialized_value = DHTProtocol.serializer.loads(record.value)
        if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
            deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
            return {field_name: {deserialized_subkey: deserialized_value}}
        else:
            if isinstance(deserialized_value, dict):
                raise ValueError(
                    f"Record {record} contains an improperly serialized dictionary (you must use "
                    f"a DictionaryDHTValue of serialized values instead of a `dict` subclass)"
                )
            return {field_name: deserialized_value}

    @staticmethod
    def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
        inner_errors = exc.errors()
        return (
            len(inner_errors) == 1
            and inner_errors[0]["type"] == "value_error.extra"
            and len(inner_errors[0]["loc"]) == 1  # Require the extra field to be on the top level
        )

    def merge_with(self, other: RecordValidatorBase) -> bool:
        if not isinstance(other, SchemaValidator):
            return False

        self._schemas.extend(other._schemas)
        self._key_id_to_field_name.update(other._key_id_to_field_name)
        self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
        return True

    def __setstate__(self, state):
        self.__dict__.update(state)

        # If unpickling happens in another process, the previous model modifications may be lost
        for schema in self._schemas:
            self._patch_schema(schema)
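To see why strict types matter, the coercion difference can be demonstrated directly with Pydantic. The schemas and field name below are hypothetical examples of what one might pass to SchemaValidator:

```python
import pydantic


class LooseSchema(pydantic.BaseModel):
    step: int                      # plain int silently coerces 3.0 -> 3


class StrictSchema(pydantic.BaseModel):
    step: pydantic.StrictInt       # rejects 3.0, as SchemaValidator requires


# The loose field passes validation only because of a type conversion.
coerced = LooseSchema(step=3.0).step

# The strict field raises ValidationError instead of converting.
try:
    StrictSchema(step=3.0)
    strict_rejected = False
except pydantic.ValidationError:
    strict_rejected = True

# In the template, the strict schema would then be used as:
# validator = SchemaValidator(StrictSchema, allow_extra_keys=False)
```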

SignatureValidator (RSA and Ed25519)

Introduces the notion of protected records whose key/subkey contains the substring "[owner:ssh-ed25519 ...]" with an Ed25519 public key of the owner.

If this validator is used, changes to such records must always be signed with the corresponding private key (so only the owner can change them).

class Ed25519SignatureValidator(RecordValidatorBase):
    """
    Introduces a notion of *protected records* whose key/subkey contains substring
    "[owner:ssh-ed25519 ...]" with an Ed25519 public key of the owner.

    If this validator is used, changes to such records must always be signed with
    the corresponding private key (so only the owner can change them).
    """

    PUBLIC_KEY_FORMAT = b"[owner:_key_]"
    SIGNATURE_FORMAT = b"[signature:_value_]"

    PUBLIC_KEY_REGEX = re.escape(PUBLIC_KEY_FORMAT).replace(b"_key_", rb"(.+?)")
    _PUBLIC_KEY_RE = re.compile(PUBLIC_KEY_REGEX)
    _SIGNATURE_RE = re.compile(re.escape(SIGNATURE_FORMAT).replace(b"_value_", rb"(.+?)"))

    _cached_private_key = None

    def __init__(self, private_key: Optional[Ed25519PrivateKey] = None):
        if private_key is None:
            private_key = Ed25519PrivateKey.process_wide()
        self._private_key = private_key

        serialized_public_key = private_key.get_public_key().to_bytes()
        self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b"_key_", serialized_public_key)

    @property
    def local_public_key(self) -> bytes:
        return self._local_public_key

    def validate(self, record: DHTRecord, type: DHTRequestType) -> bool:
        public_keys = self._PUBLIC_KEY_RE.findall(record.key)
        if record.subkey is not None:
            public_keys += self._PUBLIC_KEY_RE.findall(record.subkey)
        if not public_keys:
            return True  # The record is not protected with a public key

        if len(set(public_keys)) > 1:
            logger.debug(f"Key and subkey can't contain different public keys in {record}")
            return False
        public_key = Ed25519PublicKey.from_bytes(public_keys[0])

        signatures = self._SIGNATURE_RE.findall(record.value)
        if len(signatures) != 1:
            logger.debug(f"Record should have exactly one signature in {record}")
            return False
        signature = signatures[0]

        stripped_record = dataclasses.replace(record, value=self.strip_value(record))
        if not public_key.verify(self._serialize_record(stripped_record), signature):
            logger.debug(f"Signature is invalid in {record}")
            return False
        return True

    def sign_value(self, record: DHTRecord) -> bytes:
        if self._local_public_key not in record.key and self._local_public_key not in record.subkey:
            return record.value

        signature = self._private_key.sign(self._serialize_record(record))
        return record.value + self.SIGNATURE_FORMAT.replace(b"_value_", signature)

    def strip_value(self, record: DHTRecord) -> bytes:
        return self._SIGNATURE_RE.sub(b"", record.value)

    def _serialize_record(self, record: DHTRecord) -> bytes:
        return MSGPackSerializer.dumps(dataclasses.astuple(record))

    @property
    def priority(self) -> int:
        # On validation, this validator must be executed before validators
        # that deserialize the record
        return 10

    def merge_with(self, other: RecordValidatorBase) -> bool:
        if not isinstance(other, Ed25519SignatureValidator):
            return False

        # Ignore another Ed25519SignatureValidator instance (it doesn't make sense to have several
        # instances of this class) and report successful merge
        return True
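The sign/strip/verify round trip can be sketched with the standard cryptography package in place of the template's Ed25519PrivateKey wrapper. The record bytes stand in for the output of `_serialize_record(record)`, and this sketch base64-encodes the signature so the `[signature:...]` markers stay regex-safe (the template's wrapper classes handle serialization in their own way):

```python
import base64
import re

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric import ed25519

private_key = ed25519.Ed25519PrivateKey.generate()
public_key = private_key.public_key()

# Stand-in for the bytes produced by _serialize_record(record).
record_bytes = b"serialized-record"

# Base64 keeps the signature free of "]" so the regex below stays unambiguous.
signature = base64.b64encode(private_key.sign(record_bytes))
signed_value = record_bytes + b"[signature:" + signature + b"]"

_SIGNATURE_RE = re.compile(rb"\[signature:(.+?)\]")


def strip_value(value: bytes) -> bytes:
    """Mirrors strip_value(): remove the signature tag from the value."""
    return _SIGNATURE_RE.sub(b"", value)


def verify(value: bytes) -> bool:
    """Mirrors validate(): require exactly one signature and check it."""
    found = _SIGNATURE_RE.findall(value)
    if len(found) != 1:
        return False
    try:
        public_key.verify(base64.b64decode(found[0]), strip_value(value))
        return True
    except InvalidSignature:
        return False
```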
