"""
Part 1 Server — MAC-then-Encrypt (Receiver Side)
=================================================
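Listens for an encrypted payload sent by Part1_client.py, decrypts it,
verifies the HMAC-SHA256 tag, and reports whether the message is accepted
or rejected. The client builds the payload as

    ciphertext = AES-CBC(K2, IV, PKCS7(message || HMAC-SHA256(K1, message)))

so the server must decrypt first and only then check the tag.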

Shared secrets (must match the client):
  K1  – 32-byte HMAC key
  K2  – 16-byte AES-CBC key
  IV  – 16-byte AES-CBC initialisation vector

Run this server before starting the client.

Student name : ___________________________
Student ID   : ___________________________
"""

import hmac
import hashlib
import socket
import struct

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.padding import PKCS7
from cryptography.hazmat.backends import default_backend

# Shared secrets (must be identical on the client, Part1_client.py).
# NOTE(review): hard-coded keys and a fixed IV are acceptable only for this
# lab. Reusing one IV for every message makes AES-CBC deterministic (equal
# plaintexts give equal ciphertexts); a real protocol would generate a fresh
# random IV per message and send it alongside the ciphertext.
K1 = b"0123456789abcdef0123456789abcdef"   # 32-byte HMAC-SHA256 key
K2 = b"abcdef1234567890"                   # 16-byte AES-128 key
IV = b"1234567890abcdef"                   # 16-byte AES-CBC IV

HOST = "127.0.0.1"
PORT = 65432
# Size in bytes of an HMAC-SHA256 tag (32). Derived from SHA-256's digest
# size rather than hard-coded so it always matches the hash actually used.
TAG_LENGTH = hashlib.sha256().digest_size


# Crypto helpers

def compute_mac(key: bytes, message: bytes) -> bytes:
    """
    Compute and return the HMAC-SHA256 tag of `message` under `key`.

    This must produce exactly the same bytes as compute_mac() in
    Part1_client.py. The server recomputes the tag over the decrypted message
    and compares it with the tag the client embedded in the plaintext.

    Args:
        key: HMAC key (K1 in this file).
        message: Bytes to authenticate.

    Returns:
        The raw 32-byte digest (not hex-encoded).
    """
    return hmac.new(key, message, hashlib.sha256).digest()


def unpad_data(padded_data: bytes) -> bytes:
    """
    Remove PKCS7 padding (128-bit / 16-byte block size) from `padded_data`.

    Args:
        padded_data: Decrypted bytes whose length is a multiple of 16.

    Returns:
        The data with the trailing PKCS7 padding stripped.

    Raises:
        ValueError: if the padding bytes are malformed, which is what happens
            after decrypting a tampered ciphertext or using the wrong key/IV.
    """
    # 128 is the block size in *bits*, matching AES's 16-byte block.
    unpadder = PKCS7(128).unpadder()
    return unpadder.update(padded_data) + unpadder.finalize()


def decrypt_data(key: bytes, iv: bytes, ciphertext: bytes) -> bytes:
    """
    Decrypt an AES-CBC ciphertext and return the unpadded plaintext.

    Args:
        key: AES key (this file uses the 16-byte K2).
        iv: 16-byte CBC initialisation vector.
        ciphertext: Data to decrypt; must be a whole number of 16-byte blocks.

    Returns:
        The plaintext with PKCS7 padding removed via unpad_data().

    Raises:
        ValueError: for an invalid key or IV size, a ciphertext that is not a
            multiple of the block size, or (from unpad_data) malformed padding.
    """
    # An explicit backend keeps this working on older `cryptography`
    # releases, where the argument was still required.
    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    decryptor = cipher.decryptor()
    padded = decryptor.update(ciphertext) + decryptor.finalize()
    return unpad_data(padded)


# Receiver logic 

def receiver(ciphertext: bytes) -> None:
    """
    Decrypt `ciphertext`, verify the embedded HMAC-SHA256 tag, and print the verdict.

    Expected MAC-then-Encrypt layout:
        ciphertext = AES-CBC(K2, IV, PKCS7(message || HMAC-SHA256(K1, message)))

    The message is ACCEPTED only when the tag recomputed over the decrypted
    message equals the tag carried in the plaintext. It is REJECTED in any
    of these cases:
      - decryption or unpadding fails;
      - the payload is too short to hold a tag;
      - the tags differ.

    Args:
        ciphertext: Raw ciphertext bytes received from the client.
    """
    print("\n=== Receiver (Server) ===")
    try:
        payload = decrypt_data(K2, IV, ciphertext)
    except ValueError as exc:
        # cryptography signals a bad block length or bad PKCS7 padding with
        # ValueError. NOTE(review): reporting padding failures differently
        # from MAC failures is what enables padding-oracle attacks on
        # MAC-then-Encrypt; acceptable for this lab, not for production.
        print(f"Decryption error: {exc}")
        print("Tag not matched -- message REJECTED")
        return

    # Guard explicitly: with a short payload, the slices below would
    # silently treat the whole payload as the "tag".
    if len(payload) < TAG_LENGTH:
        print(f"Decryption error: payload is {len(payload)} bytes, "
              f"shorter than the {TAG_LENGTH}-byte tag")
        print("Tag not matched -- message REJECTED")
        return

    message_bytes = payload[:-TAG_LENGTH]
    received_tag = payload[-TAG_LENGTH:]
    computed_tag = compute_mac(K1, message_bytes)

    # errors="replace" so a tampered, non-UTF-8 message is still printable.
    print(f"Decrypted message : {message_bytes.decode('utf-8', errors='replace')}")
    print(f"Received tag      : {received_tag.hex()}")
    print(f"Computed tag      : {computed_tag.hex()}")

    # Constant-time comparison avoids leaking tag bytes through timing.
    if hmac.compare_digest(received_tag, computed_tag):
        print("Tag matched -- message ACCEPTED")
    else:
        print("Tag not matched -- message REJECTED")


# Network layer 

def recv_all(conn: socket.socket, length: int) -> bytes:
    """
    Read exactly `length` bytes from `conn`, looping until all bytes arrive.

    socket.recv() may return fewer bytes than requested, so a single call is
    not enough for a framed protocol. Each request is capped at 64 KiB so a
    large (possibly hostile) length prefix does not ask recv() for a huge
    buffer in one call.

    Args:
        conn: Connected stream socket.
        length: Number of bytes to read; 0 (or less) returns b"".

    Returns:
        Exactly `length` bytes.

    Raises:
        ConnectionError: if the peer closes the connection before `length`
            bytes have been received.
    """
    max_chunk = 65536
    buffer = bytearray()
    while len(buffer) < length:
        chunk = conn.recv(min(length - len(buffer), max_chunk))
        if not chunk:
            # An empty read means the peer performed an orderly shutdown.
            raise ConnectionError(
                f"connection closed after {len(buffer)} of {length} bytes"
            )
        buffer += chunk
    return bytes(buffer)


def start_server() -> None:
    """
    Bind to HOST:PORT, accept one client connection, receive the ciphertext,
    and call receiver() to verify it.

    Wire format: a 4-byte big-endian unsigned integer (the ciphertext length)
    followed by that many bytes of ciphertext.

    Both sockets are closed by `with` blocks before verification runs. If the
    client disconnects before a full frame arrives, the error is reported and
    nothing is verified.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
        # Allow an immediate restart without waiting for TIME_WAIT to expire.
        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srv.bind((HOST, PORT))
        srv.listen(1)
        print(f"Server listening on {HOST}:{PORT} ...")

        conn, addr = srv.accept()
        with conn:
            print(f"Connected by {addr}")
            try:
                raw_len = recv_all(conn, 4)
                (msg_len,) = struct.unpack(">I", raw_len)
                ciphertext = recv_all(conn, msg_len)
            except ConnectionError as exc:
                print(f"Connection error: {exc}")
                return

    print(f"Received {len(ciphertext)} bytes of ciphertext")
    print(f"Ciphertext (hex): {ciphertext.hex()}")
    receiver(ciphertext)


if __name__ == "__main__":
    # Script entry point: run the receiver for a single client connection.
    start_server()
