#!/usr/bin/env python3.4

import math
import sys

# Compute all primes in the closed interval [minPrime, maxPrime]
# using the sieve of Eratosthenes.
# Input:
# minPrime: smallest integer to check if prime. Type: integer
# maxPrime: biggest integer to check if prime. Type: integer
# Output:
# primes: ascending list of all primes p with minPrime <= p <= maxPrime;
#         empty if the interval contains no primes (e.g. maxPrime < 2 or
#         minPrime > maxPrime). Type: list of integers
def computePrimes(minPrime, maxPrime):
    if maxPrime < 2:
        return []
    # isPrime[k] tells whether k is prime, for 0 <= k <= maxPrime
    isPrime = [True] * (maxPrime + 1)
    isPrime[0] = isPrime[1] = False
    k = 2
    while k * k <= maxPrime:
        if isPrime[k]:
            # multiples below k*k were already crossed out by smaller primes
            for multiple in range(k * k, maxPrime + 1, k):
                isPrime[multiple] = False
        k += 1
    # numbers below 2 are never prime, so start the scan at 2 at the earliest
    primes = [candidate for candidate in range(max(minPrime, 2), maxPrime + 1)
              if isPrime[candidate]]
    return primes

# Compute a public key
# Input:
# p, q: primes. Type: integer. Constraints: 50 < p, q < 600, p != q
# Output:
# (n, e): tuple of integers representing the public key, where n = p * q and
#         e is the smallest integer > 50 that is coprime to phi(n)
# Raises:
# AssertionError if p or q violate the constraints above.
def computePubKey(p, q):
    # guard against inputs this implementation (and its reference values)
    # is not designed for
    assert p > 50
    assert q > 50
    assert p < 600
    assert q < 600
    assert p != q
    # Note: we do not do any primality tests here!

    n = p * q
    phi = computePhi(p, q)
    # The reference values expect the smallest e > 50 coprime to phi(n); e
    # itself need not be prime. The loop terminates because every
    # e = 1 (mod phi) is coprime to phi.
    e = 51
    while gcd(e, phi) != 1:
        e += 1
    # n and e must be integers!
    return (n, e)
	
	

	
# Compute a private key
# Input:
# e: public exponent, as in lecture. Type: integer. Must be coprime to phi.
# phi: phi(n), as in lecture. Type: integer
# Output:
# d: private key, the inverse of e modulo phi, normalised to 0 <= d < phi.
#    Type: integer
# Raises:
# ValueError if e has no inverse modulo phi (gcd(e, phi) != 1).
def computePrivKey(e, phi):
    # eea returns x, y with e*x + phi*y = gcd(e, phi); when that gcd is 1,
    # e*x = 1 (mod phi), so x is the inverse of e up to a multiple of phi
    x, y = eea(e, phi)
    if e * x + phi * y != 1:
        raise ValueError("e is not invertible modulo phi")
    # eea may return a negative x; reduce it into [0, phi)
    d = x % phi
    # d is the private key, an integer
    return d


# gcd() computes the greatest common divisor with the iterative Euclidean
# algorithm (the Bezout coefficients from eea() are not needed here).
# Input:
# a, b: numbers to work on. Type: integer (may be negative or zero)
# Output:
# the non-negative gcd of a and b; gcd(0, 0) is 0. Type: integer
def gcd(a, b):
    x, y = abs(a), abs(b)
    while y != 0:
        x, y = y, x % y
    return x


# eea is the Extended Euclidean Algorithm (iterative version)
# Input:
# a, b: numbers to work on. Type: integer
# Output:
# (x, y): numbers for which ax + by = gcd(a,b) holds. Type: tuple of integers
#         For non-negative a, b the sum ax + by equals gcd(a, b) exactly;
#         with negative inputs it may come out as -gcd(a, b).
def eea(a, b):
    # Invariant: a*old_x + b*old_y == old_r and a*x + b*y == r
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_x, x = x, old_x - quotient * x
        old_y, y = y, old_y - quotient * y
    # old_r is now the gcd, and (old_x, old_y) are its Bezout coefficients
    return (old_x, old_y)



# Compute Euler's totient phi(n) for n = p * q, where p and q are primes
# Input:
# p, q: primes. Type: integer (primality is not checked)
# Output:
# o: phi(n). Type: integer. (p - 1) * (q - 1) for distinct primes;
#    p * (p - 1) when p == q, because then n = p^2.
def computePhi(p, q):
    if p == q:
        o = p * (p - 1)
    else:
        o = (p - 1) * (q - 1)
    return o


# Compute an encrypted message (textbook RSA: c = m^e mod n)
# Input:
# m: the message. Type: integer. Constraint: 0 <= m < n
# e: the public exponent from the public key (n, e). Type: integer
# n: the modulus p * q from the public key (n, e). Type: integer
# Output:
# ciphertext: encrypted message. Type: integer
# Raises:
# ValueError if m lies outside [0, n); such a message is only determined
# modulo n, so decryption could not recover it.
def encrypt(m, e, n):
    if not 0 <= m < n:
        raise ValueError("message must satisfy 0 <= m < n")
    # three-argument pow performs modular exponentiation without building m**e
    ciphertext = pow(m, e, n)
    return ciphertext


# Decrypt an encrypted message (textbook RSA: m = c^d mod n)
# Input:
# c: the ciphertext. Type: integer
# d: the private key. Type: integer
# n: the product of p and q. Type: integer
# Output:
# decryptedtext: the decrypted message, in the range [0, n). Type: integer
def decrypt(c, d, n):
    # three-argument pow performs modular exponentiation without building c**d
    decryptedtext = pow(c, d, n)
    return decryptedtext

# A simple padding scheme.
# It only illustrates that RSA needs padding; the scheme itself is very
# insecure and probably even weakens RSA!
# For real use (outside of this challenge) have a look at e.g. RSA-OAEP.
# Don't change anything here!
#
# Input:
# text: the message to pad. Type: integer. Must not be 0 or 1 and must be
#       at least one bit shorter than n.
# n: the RSA modulus. Type: integer
# Output:
# text XOR mask, where mask has (bitlength(n) - 1) bits with every
# even-position bit set (...0101). If that XOR would give 0 or 1, text is
# returned unchanged. Type: integer
# Side effects:
# Prints a message and terminates the process via sys.exit() when text is 0/1
# or too long. NOTE(review): sys.exit() without an argument exits with status 0
# even on these errors.
# Uses the math module, which must be imported at module level.
def pad(text, n):
    if text == 0 or text == 1:
        print("Can't handle 0 or 1 as text")
        sys.exit()
    # bit length of n, computed via a float logarithm
    # NOTE(review): math.log is floating point and may be off by one near exact
    # powers of two; int.bit_length() would be exact.
    bits=math.floor(math.log(n,2))+1 
    if math.floor(math.log(text,2))+1>= bits:
        print("Length of text must be at least 1 bit shorter then n to avoid problems with our padding")
        sys.exit()
    mask=0 
    for i in reversed(range(0,bits-1)): #compute the mask we XOR with: ...0101, 1 bit shorter than n
        if i%2==0:
            mask+=int(math.pow(2,i)) 
    padded=text^mask
    if padded == 0 or padded == 1: #if our text after padding would be 0 or 1 don't pad
        return text
    else:
        return padded

# Removes the padding from a given text.
# XOR with the same mask is its own inverse, so applying pad() a second time
# restores the original text. This also holds when pad() left the text
# unchanged, because pad() then leaves it unchanged again.
# Input:
# text: the padded message. Type: integer
# n: the RSA modulus used for padding. Type: integer
# Output:
# the unpadded message. Type: integer
# Don't change anything here!
def unpad(text,n):
    return pad(text,n)



# If you simply execute this file instead of the client, your implementation
# will be checked against some reference values:
if __name__ == "__main__":
    # testsamples format: list of (p, q, n, e, phi, d, m, c)
    # IMPORTANT: we chose the smallest possible e (> 50). If you did this too,
    # everything should pass once your implementation is working.
    # If you chose another (working) e, your implementation may be ok, but this
    # test will report it as wrong.
    testsamples = [(53, 59, 3127, 51, 3016, 1715, 423, 1369),
                   (239, 293, 70027, 53, 69496, 5245, 39545, 20070),
                   (131, 223, 29213, 53, 28860, 9257, 7424, 8958),
                   (577, 599, 345623, 53, 344448, 337949, 123456, 1831),
                   (271, 563, 152573, 53, 151740, 148877, 66666, 684)]
    errors = 0
    print("Initiate Test:")
    for test in testsamples:

        p = test[0]
        q = test[1]
        m = test[6]
        print("checking for p=%d, q=%d, m=%d" % (p, q, m))
        # check public key parts
        n, e = computePubKey(p, q)
        if n != test[2]:
            print("n wrong: expected %d, calculated %d" % (test[2], n))
            errors += 1
        else:
            print("n ok (%d)" % n)
        if e != test[3]:
            print("e wrong: expected %d, calculated %d (did you choose the smallest possible e?)" % (test[3], e))
            errors += 1
        else:
            print("e ok (%d)" % e)

        # check phi
        phi = computePhi(p, q)
        if phi != test[4]:
            print("phi wrong: expected %d, calculated %d" % (test[4], phi))
            errors += 1
        else:
            print("phi ok (%d)" % phi)

        # check private key
        d = computePrivKey(e, phi)
        if d != test[5]:
            print("d wrong: expected %d, calculated %d" % (test[5], d))
            errors += 1
        else:
            print("d ok (%d)" % d)

        # sanity check of the reference sample itself; a failure here means the
        # sample is broken, not your implementation.
        # Note for those of you who worked on the old version, here is an
        # explanation of what we did wrong: in the old version we chose e as the
        # smallest prime number coprime to phi(n). In fact, e does not have to
        # be prime, it only has to be coprime to phi(n). Because we restricted e
        # to primes, the first test case used e = 53 (prime and coprime to
        # phi(n)), but the correct e is 51 (coprime to phi(n), not prime).
        # Thanks to Michael Kubitza for finding the bug!
        assert test[3] * test[5] % test[4] == 1

        # The checks below exercise your gcd and your key pair. Failures are
        # counted instead of asserted, so the remaining checks and the final
        # summary still run.
        sampleGcd = gcd(test[3], test[4])
        if sampleGcd != 1:
            print("gcd wrong: expected 1, calculated %d" % sampleGcd)
            errors += 1
        if e * d % phi != 1:
            print("e * d mod phi wrong: expected 1, calculated %d" % (e * d % phi))
            errors += 1

        # check encryption
        c = encrypt(m, e, n)
        if c != test[7]:
            print("c wrong: expected %d, calculated %d" % (test[7], c))
            errors += 1
        else:
            print("c ok (%d)" % c)

        # check decryption; because c = encrypt(m, e, n), this also verifies
        # that decrypt(encrypt(m)) == m
        m2 = decrypt(c, d, n)
        if m2 != m:
            print("decrypted c wrong: expected %d, calculated %d" % (m, m2))
            errors += 1
        else:
            print("decrypted c ok (%d)" % m)

        print("\n--------------------------------------------------------\n")

    print("Finished testing:")
    print("%d Errors" % errors)
    if errors > 0:
        print("Check output above!")
