#-*- coding: utf-8 -*-
# (c) 2015 Anders Andersen
# See http://www.cs.uit.no/~aa/dist/tools/py/COPYING for details

# Load modules
import sys, os
from noop.ip.tcp import *
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding

# Size of the randomly generated AES session key, in bits
KEYSIZE = 256

# Command line arguments:
#   argv[1]: file name of the server's PEM-encoded RSA public key (required)
#   argv[2]: server host name (optional, default "localhost")
#   argv[3]: server TCP port (optional, default 3456)
host = "localhost"
port = 3456
if len(sys.argv) < 2:
    sys.exit("Usage: %s pubkey-file [host [port]]" % (sys.argv[0],))
# The context manager closes the key file even if reading or parsing fails
with open(sys.argv[1], "rb") as frsakey:
    rsakey = serialization.load_pem_public_key(
        frsakey.read(),
        backend=default_backend())
if len(sys.argv) > 2:
    host = sys.argv[2]
if len(sys.argv) > 3:
    port = int(sys.argv[3])

# A class to send encrypted data
class SecComSend:
    """Send data to a server, encrypted with AES in CTR mode.

    Each instance generates a fresh random AES key (KEYSIZE bits) and a
    random CTR initial counter block.  On the first call to send() both are
    transmitted to the server, encrypted with the server's RSA public key
    (OAEP padding).  Every call to send() then transmits its message
    encrypted with the continuing AES-CTR key stream, so messages must be
    decrypted by the receiver in the same order they were sent.
    """

    # Save address and public key, and generate secret shared key
    def __init__(self, address, pubkey):
        """Create a sender for *address* using the RSA public key *pubkey*.

        The AES session key and counter block are generated here, but
        nothing is sent until the first call to send().
        """
        self.addr = address      # destination passed to tcpsend()
        self.pubkey = pubkey     # server's RSA public key object
        self.aeskey = os.urandom(KEYSIZE // 8)
        # Initial counter block (nonce) for CTR mode: one AES block long
        self.irv = os.urandom(algorithms.AES.block_size // 8)
        cipher = Cipher(
            algorithms.AES(self.aeskey),
            modes.CTR(self.irv),
            backend=default_backend())
        self.encryptor = cipher.encryptor()
        self._first_time = True

    def _send_key(self):
        """Send counter block + AES key to the server, RSA-OAEP encrypted."""
        # NOTE(review): OAEP with SHA-1 must match what the receiver uses;
        # consider moving both sides to SHA-256.
        tcpsend(
            self.addr,
            self.pubkey.encrypt(
                self.irv + self.aeskey,
                padding.OAEP(
                    mgf=padding.MGF1(algorithm=hashes.SHA1()),
                    algorithm=hashes.SHA1(),
                    label=None)))

    # Encrypt message and send it
    def send(self, msg):
        """Encrypt the bytes *msg* and send them to the server.

        The first call also sends the RSA-encrypted session key first.
        May be called any number of times until close() is called.
        """
        if self._first_time:
            self._first_time = False
            self._send_key()
        # CTR is a stream mode: update() returns exactly len(msg) bytes and
        # the key stream continues across calls.  The encryptor must not be
        # finalized here, because a finalized context raises
        # AlreadyFinalized on the next update().
        tcpsend(self.addr, self.encryptor.update(msg))

    def close(self):
        """Finalize the encryptor and send any remaining ciphertext.

        After close(), send() can no longer be used.  For CTR mode
        finalize() produces no extra bytes, so normally nothing is sent.
        """
        tail = self.encryptor.finalize()
        if tail:
            tcpsend(self.addr, tail)

# Build the server address, wrap it in an encrypting sender, send a single
# greeting message, and flush pending TCP output before the script exits.
server_addr = IPaddr(node=host, port=port)
scs = SecComSend(server_addr, rsakey)
greeting = "hello".encode()
scs.send(greeting)
tcpflush()