#-*- coding: utf-8 -*-
# (c) 2015 Anders Andersen
# See http://www.cs.uit.no/~aa/dist/tools/py/COPYING for details

from sys import argv, stdin, stdout
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding

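# argv[1]: name of the PEM file holding the RSA private key (read)
# argv[2]: name of the file holding the RSA-encrypted AES nonce + key (read)
# argv[3]: password for the RSA private key (optional)
# stdin: ciphertext (input data)
# stdout: plaintext (output data)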
if len(argv) > 2:

    # Password protecting the RSA private key, if any; load_pem_private_key
    # expects bytes (or None for an unencrypted key)
    passwd = None
    if len(argv) > 3:
        passwd = argv[3].encode()

    # Read the RSA private key; 'with' closes the file even if parsing fails
    with open(argv[1], "rb") as frsakey:
        rsakey = serialization.load_pem_private_key(
            frsakey.read(),
            password=passwd,
            backend=default_backend())

    # Read and decrypt the AES key blob. The OAEP/MGF1/SHA-1 parameters must
    # match whatever was used when the blob was encrypted, so do not change
    # them here independently of the encrypting side.
    with open(argv[2], "rb") as faeskey:
        aesinfo = rsakey.decrypt(
            faeskey.read(),
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA1()),
                algorithm=hashes.SHA1(),
                label=None))

    # The decrypted blob is the CTR nonce (one AES block) followed by the
    # AES key itself
    blocklen = algorithms.AES.block_size // 8
    irv = aesinfo[:blocklen]
    aeskey = aesinfo[blocklen:]

    # Create AES cipher
    cipher = Cipher(algorithms.AES(aeskey), modes.CTR(irv), backend=default_backend())
    decryptor = cipher.decryptor()

    # Stream ciphertext from stdin to plaintext on stdout in fixed-size
    # chunks so large inputs are not held in memory. finalize() must be
    # called exactly once, after all input has been fed to update();
    # calling update() after finalize() raises.
    chunksize = 64 * 1024
    while True:
        data = stdin.buffer.read(chunksize)
        if not data:
            break
        stdout.buffer.write(decryptor.update(data))
    stdout.buffer.write(decryptor.finalize())