# -*- coding: utf-8 -*-
"""The NOOP pubcrypto module

Copyright © 2015-2016, Anders Andersen, The Arctic University of Norway.
See http://www.cs.uit.no/~aa/dist/tools/noop/COPYING (../COPYING) for
details.

"""


# Default key size in bits, used by PrivKey.__init__ (keysize) when a new
# key is generated; AsymKey.setkey() derives the byte length as key_size // 8.
ASYMKEYSIZE = 2048


# Import system modules
import sys, os


# Python 3.4 (or greater) only! Compare the full version tuple (so e.g. 4.0
# is accepted) and raise explicitly so the check survives "python -O".
if sys.version_info < (3, 4):
    raise RuntimeError(
        "This NOOP module \"%s\" is Python 3.4 (or greater) only!" % (__name__,))


# Import third-party (cryptography) modules
from cryptography.hazmat.primitives.asymmetric import dsa, rsa, padding
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.hashes import SHA1
from cryptography.hazmat.backends import default_backend


# Load noop libraries (if available)
try:
    from noop.core.signature import signature, one, opt
    from noop.core.misc import idstr
    from noop.crypto.crypto import Key, validkey, isvalidkey, parse, \
        PUBLIC, PRIVATE, KEYCLS, KEYTYPE
    from noop.crypto.password import Password
except ImportError:
    # Stand-alone fallback: minimal no-op replacements for the noop helpers,
    # plus imports of the sibling modules by plain name.

    def signature(f=None, **kwargs):
        """No-op stand-in for noop's ``signature`` decorator.

        Supports both bare ``@signature`` and the parameterised form
        ``@signature(exc=[...])`` used in this module; the decorated
        function is returned unchanged either way.
        """
        if f is None:
            # Called with keyword arguments only: return the real decorator.
            return lambda g: g
        return f

    class one:
        """No-op stand-in for noop's ``one`` annotation marker."""
        def __init__(self, *args):
            pass

    opt = one

    def idstr(self=None):
        """Return the module name, used as the prefix of error messages."""
        return __name__

    from crypto import Key, validkey, isvalidkey, parse, \
        PUBLIC, PRIVATE, KEYCLS, KEYTYPE
    from password import Password


@signature(exc=[ValueError])
def _loadpubkey(pem:bytes, keyType:type, passwd:opt(Password)=None) -> object:
    """Load public key from PEM

    Load a public key from a PEM representation (usually read from a file).
    If *pem* is not a public-key PEM, it is loaded as a private key
    (decrypted with *passwd* when given) and its public half is returned.
    *keyType* is accepted for symmetry with _loadprivkey but is not used.

    Raises ValueError if no key can be loaded from *pem* or if the loaded
    key is not a valid public key type.

    """
    try:
        key = serialization.load_pem_public_key(pem, backend=default_backend())
    except ValueError:
        # Not a public-key PEM: try to derive the public key from a private one.
        password = bytes(passwd) if passwd else None
        try:
            privkey = serialization.load_pem_private_key(
                pem, password, backend=default_backend())
            key = privkey.public_key()
        except (ValueError, TypeError) as err:
            # cryptography raises TypeError when a password is given for an
            # unencrypted key or missing for an encrypted one; report it as
            # ValueError like every other load failure (see exc=[ValueError]).
            raise ValueError(
                idstr() + ": Unable to load public key from PEM.") from err
    if isvalidkey(type(key), keyGroup=PUBLIC, keyCort=KEYTYPE):
        return key
    raise ValueError(idstr() + ": PEM is not a valid (public) key type.")


@signature(exc=[ValueError])
def _loadprivkey(pem:bytes, keyType:type, passwd:opt(Password)=None) -> object:
    """Load private key from PEM

    Load a private key from a PEM representation (usually read from a file),
    decrypting it with *passwd* when one is given. *keyType* is accepted for
    symmetry with _loadpubkey but is not used.

    Raises ValueError if the PEM cannot be loaded (including a missing or
    unexpected password) or is not a valid private key type.

    """
    password = bytes(passwd) if passwd else None
    try:
        key = serialization.load_pem_private_key(
            pem, password, backend=default_backend())
    except TypeError as err:
        # cryptography signals a password/encryption mismatch with TypeError;
        # keep to the declared exc=[ValueError] contract.
        raise ValueError(
            idstr() + ": Unable to load private key from PEM "
            "(password missing or not expected).") from err
    if isvalidkey(type(key), keyGroup=PRIVATE, keyCort=KEYTYPE):
        return key
    raise ValueError(idstr() + ": PEM is not a valid (private) key type.")


class AsymKey(Key):
    """Common base for the asymmetric key wrappers in this module.

    The wrapped cryptography key object is kept in ``self.key``.
    """

    @signature
    def setkey(self, key: object):
        """Install *key* and record its length in bytes.

        *key* must expose ``key_size`` in bits (as cryptography's RSA/DSA
        key objects do); ``keylength`` is that value divided by 8.
        """
        self.key = key
        self.keylength = key.key_size // 8

    @signature
    def __repr__(self) -> str:
        """Return ``ClassName(<serialised key>)`` or ``ClassName(nokey=True)``.

        A wrapper counts as keyless when ``key`` was never set (constructed
        with ``nokey=True``) or is still None (e.g. ``PubKey()`` without a
        PEM); calling bytes() on such a wrapper would fail.
        """
        if getattr(self, "key", None) is None:
            return self.__class__.__name__ + "(nokey=True)"
        # NOTE(review): for an RSAPrivKey without a password, bytes() is the
        # unencrypted private key, so it ends up in the repr (and any log).
        return self.__class__.__name__ + "(" + bytes(self).decode() + ")"



    
class PubKey(AsymKey):
    """Base wrapper for asymmetric public keys.

    Concrete subclasses set ``_keyType``; without it no PEM can be loaded.
    """

    _keyType = None

    @signature
    def __init__(self,
                 pem:opt(bytes) = None,
                 passwd:opt(Password) = None,
                 nokey:opt(bool) = False):
        """Initialise the wrapper, optionally loading a key from *pem*.

        :param pem: PEM data: a public key, or a private key whose public
            half is extracted by _loadpubkey.
        :param passwd: password for an encrypted private-key PEM; stored on
            the instance only when given.
        :param nokey: when true, leave ``key`` unset entirely.
        :raises ValueError: if *pem* is given but ``_keyType`` is not set.
        """
        if passwd:
            self.passwd = passwd
        if nokey:
            return
        self.key = None
        if not pem:
            return
        if not self._keyType:
            raise ValueError(idstr(self) + ": Unknown public key type")
        if passwd:
            loaded = _loadpubkey(pem, self._keyType, passwd)
        else:
            loaded = _loadpubkey(pem, self._keyType)
        self.setkey(loaded)


class RSAKey(AsymKey):
    """Shared RSA-OAEP settings for the RSA public and private key wrappers."""

    # OAEP padding whose MGF1 mask and label hash both use SHA1.
    _padding = padding.OAEP
    _hash = SHA1
    _mgf = padding.MGF1

    @signature
    def __getattr__(self, name:str) -> object:
        """Provide the computed ``payloadlength`` attribute.

        ``payloadlength`` is the largest plaintext (in bytes) a single OAEP
        encryption accepts: keylength - 2 - 2 * hash digest size. It is
        recomputed on every access rather than cached on the instance, so it
        stays correct if setkey() later installs a key of a different size.
        It is 0 while no key length is known.

        :raises AttributeError: for any other missing attribute.
        """
        if name == "payloadlength":
            if not hasattr(self, "keylength"):
                return 0
            return self.keylength - 2 - 2 * self._hash.digest_size
        raise AttributeError(idstr(self) + ": unknown attribute " + name)

    
@validkey
class RSAPubKey(PubKey, RSAKey):
    """RSA public key wrapper supporting OAEP encryption."""

    _keyType = rsa.RSAPublicKey
    _keyGroup = PUBLIC

    @signature
    def encrypt(self, data:bytes) -> bytes:
        """Encrypt *data* with RSA-OAEP using the class padding settings.

        :raises ValueError: if no key is loaded, or if *data* is longer than
            ``payloadlength`` bytes.
        """
        # Without a key, payloadlength is 0 and the size check below would
        # report a misleading "too large payload" (or self.key.encrypt would
        # fail on None for empty data), so report the real problem.
        if getattr(self, "key", None) is None:
            raise ValueError(idstr(self) + ": no public key loaded")
        if len(data) > self.payloadlength:
            raise ValueError(idstr(self) + ": too large payload")
        return self.key.encrypt(
            data,
            self._padding(
                mgf=self._mgf(algorithm=self._hash()),
                algorithm=self._hash(),
                label=None))

    # def verify(self, ...):


class DSAKey(AsymKey):
    """Common base for the DSA key wrappers (no DSA-specific behavior yet)."""
    pass


@validkey
class DSAPubKey(DSAKey, PubKey):
    """DSA public key wrapper; only loading (via PubKey) is provided here."""
    _keyType = dsa.DSAPublicKey  # cryptography class wrapped (see validPubKeyTypes)
    _keyGroup = PUBLIC


class PrivKey(AsymKey):
    """Base wrapper for asymmetric private keys.

    Subclasses set ``_keyType`` (the cryptography private-key class) and
    ``_pubKeyClass`` (the wrapper returned by getpubkey()). Subclasses that
    can generate keys also set ``_keygen`` and ``_pubexp``. ``_keygen`` is
    called as ``self._keygen(public_exponent=..., key_size=..., backend=...)``,
    so it must not bind ``self`` (e.g. wrap a plain function in staticmethod).
    """

    _keyType = None
    _pubKeyClass = PubKey

    @signature
    def __init__(self,
                 pem:opt(bytes) = None,
                 passwd:opt(Password) = None,
                 keysize:opt(int) = ASYMKEYSIZE,
                 nokey:opt(bool) = False):
        """Load a private key from *pem*, or generate a new one.

        :param pem: PEM-encoded private key. When empty or None, a new key of
            *keysize* bits is generated if the subclass defines ``_keygen``.
        :param passwd: password for an encrypted *pem*; kept as ``self.passwd``.
        :param keysize: size in bits of a generated key.
        :param nokey: create an empty wrapper (neither ``key`` nor ``passwd``
            is set).
        :raises ValueError: if the key type is unknown or no key can be
            generated.
        """
        if nokey:
            return
        self.key = None
        self.passwd = passwd
        if pem and self._keyType:
            # Hand over the Password itself, as PubKey does: _loadprivkey
            # declares passwd as opt(Password) and converts it to bytes.
            if passwd:
                self.setkey(_loadprivkey(pem, self._keyType, passwd))
            else:
                self.setkey(_loadprivkey(pem, self._keyType))
        elif pem:
            raise ValueError(idstr(self) + ": Unknown private key type")
        elif hasattr(self, "_keygen"):
            self.setkey(
                self._keygen(
                    public_exponent=self._pubexp,
                    key_size=keysize,
                    backend=default_backend()))
        else:
            raise ValueError(idstr(self) + ": Unable to generate key")

    @signature
    def getpubkey(self) -> PubKey:
        """Return a new ``_pubKeyClass`` wrapper holding this key's public half."""
        pubkey = self._pubKeyClass()
        pubkey.setkey(self.key.public_key())
        return pubkey


@validkey
class RSAPrivKey(RSAKey, PrivKey):
    """RSA private key wrapper: PKCS#8 PEM serialisation and OAEP decryption."""

    _keyType = rsa.RSAPrivateKey
    _keyGroup = PRIVATE
    _pubKeyClass = RSAPubKey
    _encoding = serialization.Encoding.PEM
    _format = serialization.PrivateFormat.PKCS8
    # staticmethod: a plain Python function stored as a class attribute is
    # bound on instance access, so PrivKey.__init__'s
    # self._keygen(public_exponent=..., ...) would pass ``self`` as the first
    # positional argument and fail with a TypeError.
    _keygen = staticmethod(rsa.generate_private_key)
    _pubexp = 65537

    @signature
    def __parse__(self, msg:bytes, passwd:opt(Password) = None) -> tuple:
        """Return the constructor arguments for *msg* (plus *passwd*, if set)."""
        if passwd:
            return (msg, passwd)
        else:
            return (msg,)

    @signature
    def __bytes__(self) -> bytes:
        """Serialise the key as PKCS#8 PEM.

        The PEM is encrypted with the stored password when one is set
        (BestAvailableEncryption), otherwise it is written unencrypted.
        A wrapper created with ``nokey=True`` has no ``passwd`` attribute;
        it is treated as having no password.
        """
        passwd = getattr(self, "passwd", None)
        if passwd:
            encryption = serialization.BestAvailableEncryption(bytes(passwd))
        else:
            encryption = serialization.NoEncryption()
        return self.key.private_bytes(
            encoding=self._encoding,
            format=self._format,
            encryption_algorithm=encryption)

    @signature
    def __load__(self, msg:bytes, passwd:opt(Password) = None) :
        """Re-initialise this wrapper from the PEM *msg* (and *passwd*)."""
        if passwd:
            self.__init__(msg, passwd)
        else:
            self.__init__(msg)

    @signature
    def decrypt(self, data:bytes) -> bytes:
        """Decrypt *data* with RSA-OAEP using the class padding settings."""
        return self.key.decrypt(
            data,
            self._padding(
                mgf=self._mgf(algorithm=self._hash()),
                algorithm=self._hash(),
                label=None))


@validkey
class DSAPrivKey(DSAKey, PrivKey):
    """DSA private key wrapper.

    No ``_keygen`` is defined here, so (unless ``Key`` supplies one, which is
    not verified from this module) PrivKey.__init__ without a PEM raises
    ValueError("Unable to generate key").
    """

    _keyType = dsa.DSAPrivateKey  # matched with isinstance() in loadprivkey()
    _keyGroup = PRIVATE
    _pubKeyClass = DSAPubKey  # wrapper class returned by getpubkey()


# TODO: populate these from validKeys instead of listing them by hand.
# The *Classes lists hold the wrapper classes; the *Types lists hold the
# cryptography key classes they wrap (their ``_keyType``), in the same order.
validPubKeyClasses = [RSAPubKey, DSAPubKey]
validPubKeyTypes = [keycls._keyType for keycls in validPubKeyClasses]
validPrivKeyClasses = [RSAPrivKey, DSAPrivKey]
validPrivKeyTypes = [keycls._keyType for keycls in validPrivKeyClasses]

@signature
def loadpubkey(pem:bytes, passwd:opt(Password)=None) -> PubKey:
    """Create a key wrapper from a PEM representation.

    *pem* (and *passwd*, when given) is handed to ``parse``, which returns
    the wrapper class to instantiate together with its constructor
    arguments; the new instance is returned.
    """
    parse_args = (pem, passwd) if passwd else (pem,)
    keycls, ctor_args = parse(*parse_args)
    return keycls(*ctor_args)


@signature
def loadprivkey(pem:bytes, passwd:opt(Password)=None) -> PrivKey:
    """Load a private key from PEM and wrap it in the matching key class.

    :param pem: PEM-encoded private key.
    :param passwd: password for an encrypted PEM, if any.
    :returns: an instance of the first class in ``validPrivKeyClasses``
        whose ``_keyType`` matches the loaded key. Its ``passwd`` is None
        (the password is not remembered), so for RSAPrivKey ``bytes()``
        yields an unencrypted PEM. NOTE(review): confirm this is intended.
    :raises ValueError: if the PEM cannot be loaded or is not a recognized
        private key type.
    """
    if passwd:
        key = _loadprivkey(pem, PrivKey, passwd)
    else:
        key = _loadprivkey(pem, PrivKey)
    for pkc in validPrivKeyClasses:
        if isinstance(key, pkc._keyType):
            # nokey=True: do not let the constructor generate a throwaway
            # key (costly for RSA, and a ValueError for classes without
            # _keygen such as DSAPrivKey).
            pkey = pkc(nokey=True)
            break
    else:
        raise ValueError(idstr() + ": PEM is not a recognized private key.")
    # nokey=True leaves passwd unset; give it the value the default
    # construction used to leave behind.
    pkey.passwd = None
    pkey.setkey(key)
    return pkey


@signature
def pwprotectkey(key:Key, passwd:Password) -> bytes:
    """Attach *passwd* to *key* and return the key serialised with bytes().

    *key* is mutated: its ``passwd`` attribute is replaced. For RSAPrivKey
    the result is a PEM encrypted with *passwd*; for public-key wrappers the
    result depends on their ``__bytes__`` (not defined in this module).

    :raises ValueError: if *key* is not a PubKey or PrivKey wrapper.
    """
    # isinstance (not an exact type() match) so concrete subclasses such as
    # RSAPrivKey are accepted.
    if not isinstance(key, (PubKey, PrivKey)):
        raise ValueError(idstr() + ": Unable to password protect key.")
    key.passwd = passwd
    return bytes(key)

#??? (loadprivkey?)
#@signature(exc=[ValueError])
#def pwexposekey(msg:bytes, keyType:type, passwd:Passwd) -> Key:
#    key = keyType(nokey=True)
#    return key.__exposed__(msg, passwd)