"""
Part 1 Client — MAC-then-Encrypt (Sender Side)
===============================================
Encrypts a message with the MAC-then-Encrypt scheme and sends the
ciphertext to Part1_server.py over TCP.

Two test cases are included:
  Case 1 – valid ciphertext    → server should report "message ACCEPTED"
  Case 2 – tampered ciphertext → server should report "message REJECTED"

Run the server first, then run this client.

Student name : ___________________________
Student ID   : ___________________________
"""

import hmac
import hashlib
import socket
import struct

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.padding import PKCS7
from cryptography.hazmat.backends import default_backend

# Shared secrets (must be identical on the server).
# These values are hard-coded so that client and server agree without any key
# exchange. That is acceptable only for this lab: real systems must never
# hard-code keys, and AES-CBC needs a fresh, unpredictable IV per message —
# reusing one IV makes identical plaintext prefixes encrypt identically.
K1 = b"0123456789abcdef0123456789abcdef"  # 32-byte HMAC-SHA256 key
K2 = b"abcdef1234567890"                  # 16-byte AES key (AES-128)
IV = b"1234567890abcdef"                  # 16-byte AES-CBC IV (static; see above)

HOST = "127.0.0.1"
PORT = 65432


# Crypto helpers

def compute_mac(key: bytes, message: bytes) -> bytes:
    """
    Compute an HMAC-SHA256 tag over `message` using `key`.

    Args:
        key: secret HMAC key (bytes).
        message: data to authenticate (bytes).

    Returns:
        The raw 32-byte HMAC-SHA256 digest (not hex-encoded).
    """
    return hmac.new(key, message, hashlib.sha256).digest()


def pad_data(data: bytes) -> bytes:
    """
    Apply PKCS7 padding to `data` for a 128-bit (16-byte) block size.

    PKCS7 always adds between 1 and 16 bytes, so the result is a non-empty
    multiple of 16 bytes even when `data` is empty or already block-aligned.

    Returns:
        The padded bytes, ready for AES-CBC encryption.
    """
    padder = PKCS7(128).padder()
    return padder.update(data) + padder.finalize()


def encrypt_data(key: bytes, iv: bytes, plaintext: bytes) -> bytes:
    """
    Pad `plaintext` with PKCS7, then encrypt it using AES in CBC mode.

    Args:
        key: AES key; 16, 24 or 32 bytes (the library rejects other sizes).
        iv: 16-byte CBC initialisation vector.
        plaintext: data to encrypt; any length (padding is applied here).

    Returns:
        The raw ciphertext, whose length is a multiple of 16 bytes.
    """
    padded = pad_data(plaintext)
    # An explicit backend keeps this working on older `cryptography` releases
    # that still require it; newer releases accept and ignore it.
    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    encryptor = cipher.encryptor()
    return encryptor.update(padded) + encryptor.finalize()


# Sender logic 

def sender(message: str) -> bytes:
    """
    Build and return the MAC-then-Encrypt ciphertext for `message`.

    Steps:
        1. UTF-8 encode the message.
        2. Compute an HMAC-SHA256 tag over the message bytes with K1.
        3. Form the payload as message_bytes || tag.
        4. Encrypt the payload with AES-CBC under K2 and IV.

    The original message, the tag and the ciphertext are printed (hex) for
    inspection.

    Returns:
        The ciphertext bytes to send to the server.
    """
    message_bytes = message.encode("utf-8")
    tag = compute_mac(K1, message_bytes)
    # MAC-then-Encrypt: the tag is appended to the plaintext and encrypted
    # along with it, so the tag never travels in the clear.
    payload = message_bytes + tag
    ciphertext = encrypt_data(K2, IV, payload)

    print(f"Original message : {message}")
    print(f"HMAC tag (hex)   : {tag.hex()}")
    print(f"Ciphertext (hex) : {ciphertext.hex()}")
    return ciphertext


# Network layer 

def send_ciphertext(ciphertext: bytes) -> None:
    """
    Open a TCP connection to (HOST, PORT) and send `ciphertext`.

    Wire format: a 4-byte big-endian unsigned length prefix followed by the
    ciphertext bytes. `struct.pack(">I", ...)` raises struct.error if the
    ciphertext is 4 GiB or larger.

    The socket is closed automatically by the `with` block, even if connect
    or send raises (e.g. ConnectionRefusedError when the server is not running).
    """
    frame = struct.pack(">I", len(ciphertext)) + ciphertext
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect((HOST, PORT))
        # sendall() loops until every byte is written, unlike send().
        sock.sendall(frame)
    print(f"Sent {len(frame)} bytes "
          f"(4-byte length prefix + {len(ciphertext)}-byte ciphertext)")


# Test cases 

def case_1_valid_message() -> None:
    """
    Case 1: send a legitimately MACed and encrypted message.

    The server should report: Tag matched — message ACCEPTED
    """
    print("\n==============================")
    print("Case 1: Valid Message")
    print("==============================")
    message = "Transfer $100 to Bob"
    ciphertext = sender(message)
    send_ciphertext(ciphertext)


def case_2_tampered_message() -> None:
    """
    Case 2: send a message that was altered AFTER its MAC was computed.

    The server should report: Tag not matched — message REJECTED

    The tag is computed over the ORIGINAL message ("Transfer $100 to Bob"),
    but the encrypted payload pairs that stale tag with a DIFFERENT message
    ("Transfer $1000 to Bob"). When the server decrypts and recomputes the
    HMAC over the received message, the tags differ and the message is rejected.
    """
    print("\n==============================")
    print("Case 2: Tampered Message")
    print("==============================")
    original_bytes = "Transfer $100 to Bob".encode("utf-8")
    original_tag = compute_mac(K1, original_bytes)

    tampered_message = "Transfer $1000 to Bob"
    # Stale tag: it authenticates original_bytes, not the tampered text.
    tampered_payload = tampered_message.encode("utf-8") + original_tag
    tampered_ciphertext = encrypt_data(K2, IV, tampered_payload)

    print(f"Tampered message : {tampered_message}")
    print(f"Ciphertext (hex) : {tampered_ciphertext.hex()}")
    send_ciphertext(tampered_ciphertext)


# Entry point 

def main() -> None:
    """Show the test-case menu, read the user's choice and run that case."""
    dispatch = {
        "1": case_1_valid_message,
        "2": case_2_tampered_message,
    }
    menu = (
        "Select a test case to run:",
        "  1 — Valid message    (expected: ACCEPTED)",
        "  2 — Tampered message (expected: REJECTED)",
    )
    for entry in menu:
        print(entry)

    selected = input("Enter 1 or 2: ").strip()
    handler = dispatch.get(selected)
    if handler is None:
        print("Invalid choice. Please enter 1 or 2.")
        return
    handler()


if __name__ == "__main__":
    main()
