# -*- coding: utf-8 -*-
"""The NOOP symcrypto module

Copyright © 2015-2016, Anders Andersen, The Arctic University of Norway.
See http://www.cs.uit.no/~aa/dist/tools/noop/COPYING (../COPYING) for
details.

"""


# Import system modules
import os, sys


# Python 3 only! Compare the whole version tuple: checking the major and
# minor numbers separately would wrongly reject e.g. a 4.0 interpreter
# (minor version 0 is not > 3).
assert sys.version_info >= (3, 4), \
       "This NOOP module \"%s\" is Python 3.4 (or greater) only!" % (__name__,)


# Import modules
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes


# Load noop libraries (if available)
try:
    from noop.core.signature import signature, one, opt
    from noop.core.misc import idstr
    from noop.crypto.password import SYMKEYSIZE
    from noop.crypto.crypto import Key, validkey, SECRET
except ImportError:
    # Fallback no-op stubs used when the noop package is not installed.
    # ``signature`` is applied both bare (``@signature``) and with keyword
    # arguments (``@signature(exc=[...])``) in this module, so the stub
    # must accept both forms and return the decorated function unchanged.
    def signature(f=None, **kwargs):
        if f is None:
            # Called as @signature(...): return an identity decorator.
            return lambda g: g
        return f
    # Placeholder for annotation markers such as opt(bytes); it accepts
    # and ignores any constructor arguments.
    class one:
        def __init__(*args): pass
    opt = one
    def idstr(self=None): return __name__
    from password import SYMKEYSIZE
    from crypto import Key, validkey, SECRET


class SymKey(Key):
    """Base class for symmetric keys used with a block cipher.

    Subclasses set ``_keyType`` to a ``cryptography`` block cipher
    algorithm class (e.g. ``algorithms.AES``). Instantiating a class whose
    ``_keyType`` is ``None`` raises ``ValueError``. The cipher mode class is
    ``_keyMode`` (CTR by default), and the IV/nonce is one cipher block
    long.

    Each key object holds one encryptor and one decryptor context. They
    are (re)built by ``_mkcipher`` whenever the key or IV changes. Once a
    context has been finalized (``encryptdone``/``encrypt``, or
    ``decryptdone``/``decrypt``), the ``cryptography`` library refuses
    further use of it. Call ``setiv`` (or ``setkey``/``makekey``) to get
    fresh contexts.

    NOTE: in CTR mode, encrypting two messages with the same key and IV
    leaks plaintext. Use a fresh IV (``setiv()``) per message.
    """

    _keyType = None       # block cipher algorithm class; set by subclasses
    _keyMode = modes.CTR  # cipher mode class, constructed with the IV

    @signature(exc=[ValueError])
    def __init__(self,
                 keyval:opt(bytes) = None,
                 iv:opt(bytes) = None,
                 keysize:opt(int) = None,
                 nokey:opt(bool) = False):
        """Create a symmetric key.

        keyval  -- raw key bytes; a random key is generated if not given
        iv      -- IV of one block; a random IV is generated if not given
        keysize -- key size in bits for a generated key (default SYMKEYSIZE)
        nokey   -- if true, create an empty key object (no key material,
                   no IV, no cipher contexts)

        Raises ValueError for an unknown key type or an invalid key/IV size.
        """
        if self._keyType is None:
            raise ValueError(idstr(self) + ": Unknown key type.")
        # The block length depends only on the algorithm. Set it even for an
        # empty (nokey) object so that a later setiv() call works.
        self.blocklength = self._keyType.block_size//8
        if not nokey:
            # setiv() draws a random IV itself when iv is falsy.
            self.setiv(iv)
            if keyval:
                self.setkey(keyval)
            else:
                # makekey() falls back to SYMKEYSIZE when keysize is falsy.
                self.makekey(keysize)

    @signature(exc=[ValueError])
    def setiv(self, iv:opt(bytes) = None):
        """Set the IV, or draw a random one if iv is falsy.

        The IV must be exactly one block (self.blocklength bytes) long.
        Rebuilds the cipher contexts. Raises ValueError on a wrong length.
        """
        if not iv:
            iv = os.urandom(self.blocklength)
        elif len(iv) != self.blocklength:
            raise ValueError(idstr(self) + ": invalid iv size.")
        self.iv = iv
        self._mkcipher()

    @signature(exc=[ValueError])
    def setkey(self, keyval:bytes):
        """Set the key material; its size in bits must be a valid key size.

        Rebuilds the cipher contexts. Raises ValueError on a wrong length.
        """
        if len(keyval)*8 not in self._keyType.key_sizes:
            raise ValueError(idstr(self) + ": invalid key length.")
        self.keyval = keyval
        self.keylength = len(keyval)
        self._mkcipher()

    @signature(exc=[ValueError])
    def makekey(self, keysize:opt(int) = None):
        """Generate random key material of keysize bits (default SYMKEYSIZE).

        Rebuilds the cipher contexts. Raises ValueError if keysize is not a
        valid key size for the algorithm.
        """
        if not keysize:
            keysize = SYMKEYSIZE
        if keysize not in self._keyType.key_sizes:
            raise ValueError(idstr(self) + ": invalid key type/size.")
        self.keyval = os.urandom(keysize//8)
        self.keylength = len(self.keyval)
        self._mkcipher()

    @signature
    def _mkcipher(self):
        """Build the cipher and fresh encryptor/decryptor contexts.

        This does nothing until both the key and the IV are set. That is why
        __init__ can call setiv() before the key exists.
        """
        if hasattr(self, "keyval") and hasattr(self, "iv"):
            self.cipher = Cipher(
                self._keyType(self.keyval),
                self._keyMode(self.iv),
                backend=default_backend())
            self.encryptor = self.cipher.encryptor()
            self.decryptor = self.cipher.decryptor()

    @signature(exc=[ValueError])
    def __parse__(self, msg:bytes) -> tuple:
        """Parse the byte form produced by __bytes__ into (keyval, iv).

        The format is ``<name>:<keylen>:<keyval>:<blocklen>:<iv>``. Both
        lengths are decimal byte counts, so keyval and iv may themselves
        contain ``:``. Raises ValueError if msg is malformed or names a
        different key class.
        """
        try:
            name, lenstr, rest = msg.split(b":", maxsplit=2)
            length = int(lenstr)
            keyval = rest[:length]
            dummy, blenstr, iv = rest[length:].split(b":", maxsplit=2)
            blength = int(blenstr)
            name = name.decode()
        except (ValueError, TypeError, AttributeError) as err:
            # Report every malformed input uniformly, keeping the cause.
            raise ValueError(
                idstr(self) + ": Unable to parse byte representation of key."
            ) from err
        if (name == self.__class__.__name__
            and len(keyval) == length
            and len(iv) == blength
            and len(dummy) == 0):
            return (keyval, iv)
        raise ValueError(
            idstr(self) + ": Unable to parse byte representation of key.")

    @signature
    def __repr__(self) -> str:
        """Return a constructor-call representation of the key.

        NOTE(review): the result contains the raw secret key bytes.
        """
        if hasattr(self, "keyval"):
            return \
                self.__class__.__name__ + "(" + \
                repr(self.keyval) + "," + \
                repr(self.iv) + ")"
        else:
            return self.__class__.__name__ + "(nokey=True)"

    @signature
    def __bytes__(self) -> bytes:
        """Serialize as ``<name>:<keylen>:<keyval>:<blocklen>:<iv>``.

        An empty (nokey) object serializes as ``<name>::::``.
        """
        if hasattr(self, "keyval"):
            return \
                (self.__class__.__name__ + ":" + \
                 str(self.keylength) + ":").encode() + \
                self.keyval + \
                (":" + str(self.blocklength) + ":").encode() + \
                self.iv
        else:
            return (self.__class__.__name__ + "::::").encode()

    @signature
    def encryptpart(self, data:bytes) -> bytes:
        """Encrypt one chunk with the current encryptor context."""
        return self.encryptor.update(data)

    @signature
    def encryptdone(self) -> bytes:
        """Finalize the encryptor context; it cannot be reused afterwards."""
        return self.encryptor.finalize()

    @signature
    def decryptpart(self, data:bytes) -> bytes:
        """Decrypt one chunk with the current decryptor context."""
        return self.decryptor.update(data)

    @signature
    def decryptdone(self) -> bytes:
        """Finalize the decryptor context; it cannot be reused afterwards."""
        return self.decryptor.finalize()

    @signature
    def encrypt(self, data:bytes) -> bytes:
        """Encrypt a whole message (one-shot: this finalizes the encryptor)."""
        return self.encryptpart(data) + self.encryptdone()

    @signature
    def decrypt(self, data:bytes) -> bytes:
        """Decrypt a whole message (one-shot: this finalizes the decryptor)."""
        return self.decryptpart(data) + self.decryptdone()


@validkey
class AESKey(SymKey):
    """AES symmetric key; the cipher mode (CTR) is inherited from SymKey."""
    _keyGroup = SECRET          # key group marker imported with Key
    _keyType = algorithms.AES   # block cipher algorithm class