# -*- coding: utf-8 -*-
"""The NOOP crypto module

Copyright © 2015-2016, Anders Andersen, The Arctic University of Norway.
See http://www.cs.uit.no/~aa/dist/tools/noop/COPYING (../COPYING) for
details.

This module provides the basic classes and functions for
crypto-related functionality.  The `Key` class is the base class for
all symmetric and public key classes.  It should not itself be used
directly to create instances of crypto keys.

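The `validkey` class decorator is used to mark a class as a valid
crypto key (a class that can be used to create instances of crypto
keys).  The decorated class must define the class attributes
`_keyType` and `_keyGroup` (`SECRET`, `PUBLIC` or `PRIVATE`), and
`isvalidkey` checks whether a class has been registered:

>>> @validkey
... class AESKey(Key):
...     ...    # defines `_keyType` and `_keyGroup`
>>> key = AESKey()
>>> if isvalidkey(type(key), keyGroup=SECRET):
...     ...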

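The `loadkey` function converts a byte blob to a crypto key.  The key
class is determined from the blob itself: either from the registered
class name that prefixes the blob, or from the type of a PEM encoded
key (a private key may be protected by an optional password).  A
matching function that does the opposite is not necessary since the
conversion of a key to bytes will do this.  The following example
creates an AESKey `key` from a byte blob `msg`, and then converts that
key back to a byte blob:

>>> key = loadkey(msg)
>>> blob = bytes(key)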

"""


# Import system modules
import sys


# Import  modules
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend


# Python 3.4 (or greater) only!  Compare the whole version tuple (so a
# later major version with a small minor number is not rejected) and
# raise explicitly, since `assert` statements are removed by `python -O`.
if sys.version_info < (3, 4):
    raise RuntimeError(
        "This NOOP module \"%s\" is Python 3.4 (or greater) only!" % (__name__,))


# Load noop libraries (if available).  Without them, fall back to no-op
# stand-ins.  The fallback `signature` returns the decorated function
# unchanged and must support both the bare `@signature` form and the
# parameterised `@signature(exc=[...])` form (used by `Key.__load__`).
# `one`/`opt` only have to be callable, since they are used in
# annotations such as `opt(type)`.
try:
    from noop.core.signature import signature, one, opt
    from noop.core.misc import idstr
    from noop.crypto.password import Password
except ImportError:
    def signature(f=None, **kwargs):
        if f is None:
            # Called with arguments, e.g. `@signature(exc=[ValueError])`:
            # return a decorator that leaves the function unchanged.
            return lambda func: func
        return f
    class one:
        def __init__(*args): pass
    opt = one
    def idstr(self=None): return __name__
    from password import Password


# Indices into the valid keys database `_validKeys` below (which is
# populated by the `validkey` class decorator).
# Outer level: registered key classes or their key types.
KEYCLS, KEYTYPE = 0, 1
# Middle level: the flat list of all entries, or the per-group lists.
CLS, GROUP = 0, 1
# Key groups (also the value of a key class's `_keyGroup` attribute).
SECRET, PUBLIC, PRIVATE = 0, 1, 2

# A database of valid key classes and key types.  Each class (and its
# key type) is appended to the flat CLS list and to the list of its group:
#   _validKeys[KEYCLS]  -> (CLS[], GROUP(SECRET[], PUBLIC[], PRIVATE[]))
#   _validKeys[KEYTYPE] -> (CLS[], GROUP(SECRET[], PUBLIC[], PRIVATE[]))
_validKeys = (
    ([], ([], [], [])),     # KEYCLS
    ([], ([], [], [])),     # KEYTYPE
)

# Maps a key class name, or a key type, to the registered key class
_keyClassMap = dict()


# A decorator for valid keys
def validkey(cls):
    """Register `cls` as a valid crypto key class (class decorator).

    `cls` must define the class attributes `_keyType` (the type of the
    underlying key object; `loadkey` uses it to map a loaded PEM key back
    to its key class) and `_keyGroup` (`SECRET`, `PUBLIC` or `PRIVATE`).
    The class and its key type are appended to the flat lists and to the
    group lists of `_validKeys`.  The class is also made retrievable from
    `_keyClassMap` both by its class name and by its key type.

    Returns `cls` unchanged.
    """
    # Do every lookup that can fail (missing attribute, group index out
    # of range) before mutating any registry, so a faulty class is not
    # left half-registered.
    keyType = cls._keyType
    keyGroup = cls._keyGroup
    groupClasses = _validKeys[KEYCLS][GROUP][keyGroup]
    groupTypes = _validKeys[KEYTYPE][GROUP][keyGroup]
    _validKeys[KEYCLS][CLS].append(cls)
    _validKeys[KEYTYPE][CLS].append(keyType)
    groupClasses.append(cls)
    groupTypes.append(keyType)
    _keyClassMap[cls.__name__] = cls
    _keyClassMap[keyType] = cls
    return cls


@signature
def isvalidkey(
        keyClass:type,
        keyType:opt(type) = object,
        keyGroup:opt(int) = SECRET,
        keyCort:opt(int) = KEYCLS) -> bool:
    """Check whether `keyClass` is registered in the given key group.

    `keyCort` selects the registry searched: `KEYCLS` (default) for the
    classes registered with `validkey`, or `KEYTYPE` for their
    `_keyType` values.  Returns True if `keyClass` is registered in group
    `keyGroup` (`SECRET` by default, or `PUBLIC`/`PRIVATE`) and is a
    subclass of `keyType` (`object`, i.e. any class, by default).
    """
    return (keyClass in _validKeys[keyCort][GROUP][keyGroup] and
            issubclass(keyClass, keyType))


@signature
def loadkey(msg:bytes, passwd:opt(Password) = None) -> object:
    """Create a crypto key object from its byte representation `msg`.

    Two representations are recognised:

    1. ``<ClassName>:<keylength>:<keyval>:<ivlength>:<iv>``, where
       ``<ClassName>`` is the name of a class registered with `validkey`.
       The key is created as ``cls(keyval, iv)``.
    2. A PEM encoded public or private key.  A private key is decrypted
       with `passwd` when one is given.  The loaded key is wrapped in the
       registered key class whose `_keyType` it is an instance of.

    Raises ValueError if `msg` can not be parsed as either.
    """
    try:
        name, lenstr, rest = msg.split(b":", maxsplit=2)
        cls = _keyClassMap[name.decode()]
        length = int(lenstr)
        keyval = rest[:length]
        dummy, blenstr, iv = rest[length:].split(b":", maxsplit=2)
        blength = int(blenstr)
        # A message in this format whose lengths do not match is invalid;
        # it is not retried as PEM.
        if len(keyval) == length and len(iv) == blength and len(dummy) == 0:
            return cls(keyval, iv)
    except Exception:
        # Not (a constructible key in) the format above: try PEM instead.
        keyobj = _loadpemkey(msg, passwd)
        if keyobj is not None:
            return keyobj
    raise ValueError(
        idstr() + ": Unable to parse byte representation of key.")


def _loadpemkey(msg, passwd):
    """Load a PEM key from `msg` and wrap it in its registered key class.

    Returns None if `msg` is not a loadable PEM key, or if no registered
    key class has a `_keyType` matching the loaded key.
    """
    try:
        if passwd:
            key = serialization.load_pem_private_key(
                msg, bytes(passwd), backend=default_backend())
        else:
            # Without a password the key may be public or an
            # unencrypted private key.
            try:
                key = serialization.load_pem_public_key(
                    msg, backend=default_backend())
            except Exception:
                key = serialization.load_pem_private_key(
                    msg, None, backend=default_backend())
        for kt in _validKeys[KEYTYPE][CLS]:
            if isinstance(key, kt):
                keyobj = _keyClassMap[kt](nokey=True)
                if passwd:
                    keyobj.passwd = passwd
                keyobj.setkey(key)
                return keyobj
    except Exception:
        # Best effort: any failure means "not a usable PEM key"; the
        # caller reports it as a ValueError.
        pass
    return None
    

# A base Key class
class Key:
    """Base class for all symmetric and public key classes.

    It should not itself be used directly to create instances of crypto
    keys.  Concrete key classes subclass it and are registered with the
    `validkey` class decorator.  When created with a key value, an
    instance stores it in `keyval` and its length in `keylength`.
    """

    @signature
    def __init__(self, keyval:opt(bytes) = None, nokey:opt(bool) = False):
        # Leave the key unset (no `keyval`/`keylength` attributes) when an
        # empty key is requested or no (or an empty) value is given.
        if not nokey and keyval:
            self.keyval = keyval
            self.keylength = len(self.keyval)

    @classmethod
    def __parse__(cls, msg:bytes) -> tuple:
        """Parse ``b"<ClassName>:<keylength>:<keyval>"`` (see `__bytes__`).

        Returns the constructor arguments ``(keyval,)``.  Raises
        ValueError if `msg` is malformed, names another class, or the key
        value does not have the announced length.
        """
        try:
            name, lenstr, keyval = msg.split(b":", maxsplit=2)
            valid = (name.decode() == cls.__name__ and
                     len(keyval) == int(lenstr))
        except (AttributeError, TypeError, ValueError):
            # `msg` is not bytes, has too few fields, has a non-numeric
            # length, or has a class name that is not valid UTF-8.
            valid = False
        if not valid:
            raise ValueError(
                idstr() + ": Unable to parse byte representation of key.")
        return (keyval,)

    @signature
    def __str__(self) -> str:
        """Return the key value decoded as UTF-8 ("" if there is no key).

        Undecodable bytes are shown as backslash escapes.
        """
        # NOTE: the "backslashreplace" error handler supports decoding
        # only from Python 3.5 on.
        if hasattr(self, "keyval"):
            return self.keyval.decode("utf-8", "backslashreplace")
        else:
            return ""

    @signature
    def __repr__(self) -> str:
        """Return ``ClassName(<keyval>)`` or ``ClassName(nokey=True)``."""
        if hasattr(self, "keyval"):
            return self.__class__.__name__ + "(" + repr(self.keyval) + ")"
        else:
            return self.__class__.__name__ + "(nokey=True)"

    @signature
    def __bytes__(self) -> bytes:
        """Return ``b"<ClassName>:<keylength>:<keyval>"`` (the format read
        by `__parse__`), or ``b"<ClassName>::"`` if there is no key."""
        name = self.__class__.__name__
        if hasattr(self, "keyval"):
            header = name + ":" + str(self.keylength) + ":"
            return header.encode() + self.keyval
        else:
            return (name + "::").encode()

    @signature(exc=[ValueError])
    def __load__(self, msg:bytes):
        """Re-initialise this key from its byte representation `msg`.

        Raises ValueError (from `__parse__`) if `msg` can not be parsed.
        """
        self.__init__(*self.__parse__(msg))


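# Former version of `loadkey`, where the key class was given explicitly
# (the current `loadkey` determines it from `msg` itself):
#@signature(exc=[ValueError])
#def loadkey(msg:bytes, keyType:type) -> Key:
#    key = keyType(nokey=True)
#    key.__load__(msg)
#    return key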


@signature
def encryptpart(key:Key, data:bytes) -> bytes:
    """Encrypt a part of the data with `key`.

    Delegates to the key object's `encryptpart` method and returns its
    result.
    """
    encryptor = key.encryptpart
    return encryptor(data)

@signature
def encryptdone(key:Key) -> bytes:
    """Finish a part-wise encryption with `key`.

    Delegates to the key object's `encryptdone` method and returns its
    result.
    """
    finisher = key.encryptdone
    return finisher()

@signature
def decryptpart(key:Key, data:bytes) -> bytes:
    """Decrypt a part of the data with `key`.

    Delegates to the key object's `decryptpart` method and returns its
    result.
    """
    decryptor = key.decryptpart
    return decryptor(data)

@signature
def decryptdone(key:Key) -> bytes:
    """Finish a part-wise decryption with `key`.

    Delegates to the key object's `decryptdone` method and returns its
    result.
    """
    finisher = key.decryptdone
    return finisher()


# Encrypt data using key
@signature
def encrypt(key:Key, data:bytes) -> bytes:
    """Encrypt `data` with `key` in a single call.

    Delegates to the key object's `encrypt` method and returns its result.
    """
    encryptor = key.encrypt
    return encryptor(data)
    

# Decrypt data using key
@signature
def decrypt(key:Key, data:bytes) -> bytes:
    """Decrypt `data` with `key` in a single call.

    Delegates to the key object's `decrypt` method and returns its result.
    """
    decryptor = key.decrypt
    return decryptor(data)