# Source code for diffusion_layer_model

"""
Library for the DiffusionLayerModel class.

It contains the Python model used to verify the Diffusion Layer module.

@author: Timothée Charrier
"""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from cocotb.handle import HierarchyObject


class DiffusionLayerModel:
    """
    Model for the Diffusion Layer module.

    This class defines the model used to verify the Diffusion Layer module:
    it computes the expected output of the linear diffusion layer for a given
    input state and compares it against the output read from the DUT.
    """

    # Width, in bits, of one state word.
    _WORD_WIDTH: int = 64

    # Mask keeping only the low ``_WORD_WIDTH`` bits of an integer.
    _WORD_MASK: int = (1 << _WORD_WIDTH) - 1

    # Right-rotation amounts (r1, r2) applied to each state word, in word order.
    # These pairs match the linear diffusion layer of the Ascon permutation.
    _ROTATION_AMOUNTS: tuple[tuple[int, int], ...] = (
        (19, 28),
        (61, 39),
        (1, 6),
        (10, 17),
        (7, 41),
    )

    def __init__(
        self,
    ) -> None:
        """Initialize the model."""
        # Expected output state, refreshed by each call to ``assert_output``.
        self.o_state: list[int] = [0] * len(self._ROTATION_AMOUNTS)

    @staticmethod
    def rotate_right(value: int, num_bits: int, width: int = 64) -> int:
        """
        Rotate the bits of a fixed-width (64-bit by default) integer to the right.

        Parameters
        ----------
        value : int
            The input value. Only its low ``width`` bits are rotated; any
            higher bits (including the sign of a negative value) are dropped.
        num_bits : int
            The number of bits to rotate. It is reduced modulo ``width``, so
            rotating by ``width`` returns the value unchanged and a negative
            amount rotates to the left.
        width : int, optional
            The word width in bits, by default 64.

        Returns
        -------
        int
            The rotated value, in the range ``[0, 2**width)``.

        Raises
        ------
        ValueError
            If ``width`` is not positive.
        """
        if width <= 0:
            raise ValueError(f"width must be positive, got {width}")

        mask = (1 << width) - 1
        value &= mask
        # Reducing the amount keeps ``width - num_bits`` non-negative, which
        # avoids Python's "negative shift count" error for num_bits > width.
        num_bits %= width
        return ((value >> num_bits) | (value << (width - num_bits))) & mask

    def _linear_diffusion_layer(self, state: list[int]) -> list[int]:
        """
        Apply the linear diffusion layer.

        Each word ``x`` of the state is replaced by
        ``x ^ ROTR(x, r1) ^ ROTR(x, r2)``, where ``(r1, r2)`` is the rotation
        pair for that word's position and ``ROTR`` is a 64-bit right rotation.

        Parameters
        ----------
        state : list[int]
            The current state, one integer per word. Words are masked to
            64 bits before use.

        Returns
        -------
        list[int]
            The updated state after the linear diffusion layer.

        Raises
        ------
        ValueError
            If ``state`` does not contain exactly one word per rotation pair
            (five words).
        """
        expected_len = len(self._ROTATION_AMOUNTS)
        if len(state) != expected_len:
            raise ValueError(
                f"Expected a state of {expected_len} words, got {len(state)}"
            )

        words = [word & self._WORD_MASK for word in state]
        return [
            word
            ^ self.rotate_right(value=word, num_bits=r1, width=self._WORD_WIDTH)
            ^ self.rotate_right(value=word, num_bits=r2, width=self._WORD_WIDTH)
            for word, (r1, r2) in zip(words, self._ROTATION_AMOUNTS)
        ]

    @classmethod
    def _format_state(cls, words: list[int]) -> str:
        """Return ``words`` as space-separated, zero-padded 16-digit hex values."""
        return " ".join(f"{word & cls._WORD_MASK:016X}" for word in words)

    def assert_output(
        self,
        dut: HierarchyObject,
        state: list[int] | None = None,
    ) -> None:
        """
        Assert the output of the DUT and log the input and output values.

        Parameters
        ----------
        dut : HierarchyObject
            The device under test (DUT). Its ``o_state.value`` is iterated and
            each element converted with ``int()`` (assumes one element per
            state word — confirm against the HDL port).
        state : list[int], optional
            The input state. Although it defaults to None, it is required:
            the expected output cannot be computed without it.

        Raises
        ------
        TypeError
            If ``state`` is None.
        ValueError
            If ``state`` does not have five words, or if the DUT output
            (including its number of words) differs from the expected output.
        """
        if state is None:
            raise TypeError(
                "state must be provided to compute the expected output",
            )

        # Compute the expected output
        self.o_state = self._linear_diffusion_layer(state=state)

        # Get the output state from the DUT as a list of integers
        o_state: list[int] = [int(x) for x in dut.o_state.value]

        # Render every word, so a DUT returning extra or missing words shows
        # up as a mismatch instead of being silently truncated.
        input_str: str = self._format_state(state)
        expected_str: str = self._format_state(self.o_state)
        output_dut_str: str = self._format_state(o_state)

        dut._log.info("Input state : " + input_str)
        dut._log.info("Expected state : " + expected_str)
        dut._log.info("Output state : " + output_dut_str)
        dut._log.info("")

        # Check the output
        if expected_str != output_dut_str:
            error_msg: str = f"Expected: {expected_str}\nReceived: {output_dut_str}"
            raise ValueError(error_msg)