#-*- coding: utf-8 -*-
# (c) 2015 Anders Andersen
# See http://www.cs.uit.no/~aa/dist/tools/py/COPYING for details

# Load modules
import sys, os
from noop.ip.tcp import *
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding

# Command-line arguments:
#   argv[1]: RSA private key file name (PEM, read by this script)
#   argv[2]: host      (optional, default "localhost")
#   argv[3]: port      (optional, default 3456)
#   argv[4]: password  (optional; only needed if the key file is encrypted)
host = "localhost"
port = 3456
passwd = None
if len(sys.argv) < 2:
    sys.exit("Usage: %s privkey-file [host [port [password]]]" % (sys.argv[0],))
if len(sys.argv) > 4:
    # load_pem_private_key expects the password as bytes.
    passwd = sys.argv[4].encode()
# Use a context manager so the key file is closed even when parsing the key
# fails (e.g. wrong password or malformed PEM data).
with open(sys.argv[1], "rb") as frsakey:
    rsakey = serialization.load_pem_private_key(
        frsakey.read(),
        password=passwd,
        backend=default_backend())
if len(sys.argv) > 2:
    host = sys.argv[2]
if len(sys.argv) > 3:
    port = int(sys.argv[3])

# A class to receive encrypted data
class SecComReceive:
    """Receive data that a peer sends encrypted with AES in CTR mode.

    The first call to receive() reads an RSA-OAEP (SHA-1) encrypted blob that
    holds the AES-CTR nonce followed by the AES key. It decrypts that blob with
    the private key and builds an AES-CTR decryptor. Every call, including the
    first, then reads one cipher-text message and returns the decrypted bytes.
    """

    def __init__(self, address, privkey):
        """Save the peer address and the RSA private key.

        address: address handed to tcpreceive() (an IPaddr in this script).
        privkey: RSA private key object used to decrypt the session key.
        """
        self.addr = address
        self.privkey = privkey
        # Created lazily by the first receive() call.
        self.decryptor = None

    def _setup_decryptor(self):
        """Receive the RSA-encrypted nonce+key and build the AES-CTR decryptor."""
        tmp = self.privkey.decrypt(
            tcpreceive(self.addr),
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA1()),
                algorithm=hashes.SHA1(),
                label=None))
        # The first AES block (block_size is in bits) is the CTR nonce; the
        # remaining bytes are the AES key.
        ivlen = algorithms.AES.block_size // 8
        iv, aeskey = tmp[:ivlen], tmp[ivlen:]
        cipher = Cipher(algorithms.AES(aeskey), modes.CTR(iv),
                        backend=default_backend())
        self.decryptor = cipher.decryptor()

    def receive(self):
        """Receive one cipher-text message and return the plain text as bytes.

        The session key is fetched on the first call only. The decryptor is
        deliberately not finalized. CTR is a stream mode, so finalize() would
        produce no data, but it would lock the context and make every later
        receive() fail.
        """
        if self.decryptor is None:
            self._setup_decryptor()
        return self.decryptor.update(tcpreceive(self.addr))

# Bind a receiver to host:port using the loaded RSA key, receive one
# encrypted message, and print the decrypted bytes as text.
server_addr = IPaddr(node=host, port=port)
scr = SecComReceive(server_addr, rsakey)
plaintext = scr.receive()
print(plaintext.decode())