"""
Hash-then-Sign (Receiver / Verifier Side)
==========================================================
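Receives a message and its RSA-PSS signature from sender.py over TCP,
verifies the signature with the hard-coded public key, and reports the
result. Each value arrives as a length-prefixed frame (a 4-byte big-endian
length followed by that many bytes): first the message, then the signature.

Run this file first so the server is listening, then run sender.py.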
"""

import socket
import struct

from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend

# Address and TCP port the server binds to. 127.0.0.1 is the loopback
# interface, so only clients on this machine can connect. sender.py
# presumably targets the same HOST/PORT — confirm if either is changed.
HOST = "127.0.0.1"
PORT = 65433

# Hard-coded (pre-shared) RSA public key in PEM format. It must correspond to
# the private key in sender.py, otherwise every signature will be rejected.
PUBLIC_KEY_PEM = b"""-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAvXWPijwNG4W+j9kWm7ed
idQ0TyN8aSTb9gO0JXjHKAcxjOoJb2LvvystwK+F0pxa3PZ2YN9v4gv8AEf/v2J2
RFJw/8oSYXSZ20sZJQk55259+ciVtuUgxYlmmvZ6R3LUEQT8eLVCMw5R/rU3rt8m
7FGkSlwYmSJ+YaWHYpO7UzaZq3/nd3rw4kdscMtZkgc8ZUVGInxLp4rL+dEcLdpW
euf8ILtUGCvWpBYDGso5qkNtYYyC3t9V7FsdqCXv7sHVACsxo3gbgnqwx9SVt1rf
uByShiHRVCw+K5n9gkJBNM2iq/0n276hVKpsNpaWWo7LlM4P/EsxIG/tvNkHWA+p
bQIDAQAB
-----END PUBLIC KEY-----"""

# Parse the PEM once, at import time, into a reusable key object. A malformed
# key therefore fails when the module loads, not when the first message arrives.
public_key = serialization.load_pem_public_key(
    PUBLIC_KEY_PEM, backend=default_backend()
)

# Crypto helpers

def compute_hash(message: bytes) -> str:
    """Return the hex-encoded SHA-256 digest of `message`.

    The digest is for display only. ``verify_signature`` takes the raw
    message, not this value, because RSA-PSS verification hashes the message
    itself.

    Args:
        message: Raw bytes to hash.

    Returns:
        The SHA-256 digest as a 64-character lowercase hexadecimal string.
    """
    hasher = hashes.Hash(hashes.SHA256())
    hasher.update(message)
    return hasher.finalize().hex()


def verify_signature(public_key, message: bytes, signature: bytes) -> bool:
    """Return True if `signature` is a valid RSA-PSS signature over `message`.

    Verification uses PSS padding with MGF1(SHA-256), the maximum salt
    length, and SHA-256 as the message hash. These parameters must match the
    ones the signer used. Presumably sender.py signs with exactly these
    settings — confirm if verification unexpectedly fails.

    Args:
        public_key: RSA public key object exposing ``verify()``.
        message: The original (unhashed) message bytes.
        signature: The signature bytes received from the sender.

    Returns:
        True if the signature verifies, False if ``verify()`` raises
        ``InvalidSignature``.

    Any other exception raised by ``verify()`` propagates unchanged, so
    programming errors are not misreported as a rejected signature.
    """
    pss_padding = padding.PSS(
        mgf=padding.MGF1(hashes.SHA256()),
        salt_length=padding.PSS.MAX_LENGTH,
    )
    try:
        public_key.verify(signature, message, pss_padding, hashes.SHA256())
    except InvalidSignature:
        return False
    return True


# Receiver logic

def receiver(message: bytes, signature: bytes) -> None:
    """Hash the received message, verify the signature, and print the result.

    Steps:
      1. Print the received message (decoded as UTF-8).
      2. Compute its SHA-256 hash using compute_hash() and print it.
      3. Call verify_signature() with the module-level ``public_key`` and
         print "Signature ACCEPTED" or "Signature REJECTED".

    Args:
        message: Message bytes exactly as received from the network.
        signature: Signature bytes exactly as received from the network.
    """
    print("\n=== Receiver (Server) ===")
    # errors="replace": the payload comes from an arbitrary network peer and
    # may not be valid UTF-8. Show it with replacement characters rather than
    # crashing the server with UnicodeDecodeError.
    text = message.decode("utf-8", errors="replace")
    print(f"Received message : {text}")
    print(f"SHA-256 digest   : {compute_hash(message)}")

    if verify_signature(public_key, message, signature):
        print("Signature ACCEPTED")
    else:
        print("Signature REJECTED")


# Network layer

def recv_all(conn: socket.socket, length: int) -> bytes:
    """Read exactly `length` bytes from `conn`, looping until all arrive.

    TCP is a byte stream without message boundaries, so a single
    ``recv()`` may return fewer bytes than requested. This keeps reading
    until the full amount has been collected.

    Args:
        conn: A connected stream socket.
        length: Exact number of bytes to read. 0 returns ``b""``
            without touching the socket.

    Returns:
        Exactly ``length`` bytes.

    Raises:
        ValueError: If ``length`` is negative.
        ConnectionError: If the peer closes the connection before
            ``length`` bytes have arrived.
    """
    if length < 0:
        raise ValueError(f"length must be non-negative, got {length}")

    # CPython's recv(n) allocates an n-byte buffer before reading. Because
    # `length` can come from a peer-supplied header, request at most this many
    # bytes per call instead of the whole remainder at once.
    max_chunk = 64 * 1024

    buffer = bytearray()
    while len(buffer) < length:
        chunk = conn.recv(min(length - len(buffer), max_chunk))
        if not chunk:
            # recv() returning b"" means the peer closed its end.
            raise ConnectionError(
                f"connection closed after {len(buffer)} of {length} bytes"
            )
        buffer += chunk
    return bytes(buffer)


def recv_frame(conn: socket.socket, max_length: int = 16 * 1024 * 1024) -> bytes:
    """Read a 4-byte big-endian length prefix, then read that many bytes.

    Args:
        conn: A connected stream socket.
        max_length: Largest payload accepted, in bytes (default 16 MiB).
            The header is peer-controlled and can declare up to 4 GiB, so
            without a bound a single client could make the server try to
            buffer gigabytes.

    Returns:
        The frame payload (possibly empty).

    Raises:
        ConnectionError: If the peer closes the connection mid-frame, or
            declares a payload longer than ``max_length``.
    """
    header = recv_all(conn, 4)
    # ">I" = big-endian unsigned 32-bit int, matching the sender's framing.
    (length,) = struct.unpack(">I", header)
    if length > max_length:
        raise ConnectionError(
            f"frame of {length} bytes exceeds limit of {max_length} bytes"
        )
    return recv_all(conn, length)


def start_server() -> None:
    """Bind to (HOST, PORT), accept connections in a loop, and verify each message.

    Each client is expected to send two frames via the framing read by
    ``recv_frame()``: the message first, then its signature. Clients are
    handled one at a time. If ``recv_frame()`` raises ``ConnectionError``
    (e.g. the client disconnects mid-frame), the client is reported and
    dropped and the server keeps listening.

    Runs until interrupted (Ctrl+C raises KeyboardInterrupt). The ``with``
    blocks close the listening and client sockets on the way out.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        # Allow an immediate restart on the same port instead of failing with
        # "Address already in use" while the old socket sits in TIME_WAIT.
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((HOST, PORT))
        server.listen(1)
        print(f"[Server] Listening on {HOST}:{PORT} ...")

        while True:
            conn, addr = server.accept()
            with conn:
                print(f"[Server] Connection from {addr[0]}:{addr[1]}")
                try:
                    message = recv_frame(conn)
                    signature = recv_frame(conn)
                except ConnectionError as exc:
                    print(f"[Server] Dropping client {addr[0]}:{addr[1]}: {exc}")
                    continue

                print(f"[Server] Message size   : {len(message)} bytes")
                print(f"[Server] Signature size : {len(signature)} bytes")
                print(f"[Server] Signature (hex): {signature.hex()}")
                receiver(message, signature)


if __name__ == "__main__":
    # start_server() loops forever and is normally stopped with Ctrl+C.
    # Treat that as a clean shutdown instead of dumping a traceback.
    try:
        start_server()
    except KeyboardInterrupt:
        print("\n[Server] Shutting down.")
