#!/usr/bin/env python3

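# reverseBits() is used frequently in this file.
# Rationale: field elements are stored bit-reflected, so the multiplicative identity
# is 0x80....00 rather than the decimal number 1; reverseBits() lets elements be
# written in natural form (e.g. the identity as 1) and converted.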

# Originally based on:
# https://www.geeksforgeeks.org/reverse-bits-positive-integer-number-python/
def reverseBits(num,bitSize):
    """Return num with its bit order mirrored within a bitSize-bit window.

    Bit i of num moves to bit (bitSize - 1 - i); for example
    reverseBits(1, 128) == 1 << 127 and reverseBits(0b1101, 8) == 0b10110000.
    If num needs more than bitSize bits, all of its bits are reversed and
    nothing is truncated, so the result is then wider than bitSize.

    Raises:
        ValueError: if num is negative (it has no finite bit pattern).
    """
    if num < 0:
        raise ValueError("reverseBits() requires a non-negative integer")
    # format(num, 'b') gives the binary digits without the '0b' prefix
    # (e.g. format(10, 'b') == '1010'; format(0, 'b') == '0').  Reading them
    # backwards and right-padding with zeros up to bitSize digits places the
    # original low bit at the top of the bitSize-bit window.
    reversed_digits = format(num, 'b')[::-1].ljust(bitSize, '0')
    return int(reversed_digits, 2)

# In characteristic 2, addition and subtraction coincide: both are bitwise XOR.
def gf2_add(x, y):
    """Return the GF(2^n) sum of x and y (their bitwise XOR)."""
    total = y ^ x
    return total

def gf2_sub(x, y):
    """Return the GF(2^n) difference x - y, identical to x + y (bitwise XOR)."""
    difference = y ^ x
    return difference

def gf128_mul(x, y, R):
    """Multiply x by y in GF(2^128) using the bit-reflected representation.

    This is the shift-and-add scheme of NIST SP 800-38D, Algorithm 1.
    y is scanned from bit 127 down to bit 0. Each set bit XORs the current
    running multiple of x into the product. After every step the running
    multiple is shifted right by one (multiplication by the generator in
    this bit-reflected layout). Whenever a 1 bit falls off the low end, the
    reduction constant R is XORed back in.

    R: reduction constant, e.g. 0xE1 << 120 for the GCM polynomial.
    """
    product = 0
    multiple = x
    for bit in reversed(range(128)):
        if (y >> bit) & 1:
            product ^= multiple
        carry = multiple & 1
        multiple >>= 1
        if carry:
            # the shifted-out bit wraps around modulo the field polynomial
            multiple ^= R
    return product

def test_mul_vectors():
    """Check gf128_mul against reference (x, y, x*y) test vectors."""
    R = 0xE1000000000000000000000000000000
    vectors = (
        (0x0388dace60b6a392f328c2b971b2fe78,
         0x66e94bd4ef8a2c3b884cfa59ca342b2e,
         0x5e2ec746917062882c85b0685353deb7),
        (0x5e2ec746917062882c85b0685353de37,
         0x66e94bd4ef8a2c3b884cfa59ca342b2e,
         0xf38cbb1ad69223dcc3457ae5b6b0f885),
        (0xba471e049da20e40495e28e58ca8c555,
         0xb83b533708bf535d0aa6e52980d53b78,
         0xb714c9048389afd9f9bc5c1d4378e052),
    )
    for lhs, rhs, expected in vectors:
        assert gf128_mul(lhs, rhs, R) == expected


def gf2_pow_2(x, R):
    """Return x squared in GF(2^128), i.e. x * x reduced with constant R."""
    squared = gf128_mul(x, x, R)
    return squared

def is_odd(n):
    """Return True when the lowest bit of the integer n is set."""
    return (n & 1) != 0

# Square-and-multiply, see https://en.wikipedia.org/wiki/Exponentiation_by_squaring
def gf2_pow(x, n, R):
    """Return x**n in GF(2^128) (bit-reflected representation, constant R).

    Uses iterative right-to-left square-and-multiply, so only about log2(n)
    squarings are performed and there is no recursion.
    n == 0 returns the multiplicative identity, which in this representation
    is 1 << 127 (the same value as reverseBits(1, 128)); previously n == 0
    recursed forever.

    Raises:
        ValueError: if n is negative (use gf2_inv for reciprocals).
    """
    if n < 0:
        raise ValueError("gf2_pow() exponent must be non-negative")
    result = None  # None means "nothing multiplied in yet" (i.e. identity)
    base = x       # x^(2^k) while processing bit k of n
    while True:
        if n & 1:
            # The first factor is used as-is, avoiding a needless multiply by one.
            result = base if result is None else gf128_mul(result, base, R)
        n >>= 1
        if not n:
            break
        base = gf2_pow_2(base, R)
    return (1 << 127) if result is None else result

# Reciprocal (modular inverse): x^{2^{128}-2}
def gf2_inv(x, R):
    """Return x^(2^128 - 2), the multiplicative inverse of a non-zero x.

    In GF(2^128), x^(2^128 - 1) == 1 for every non-zero x, so x^(2^128 - 2)
    is its reciprocal.  The loop accumulates
    x^(2^0 + 2^1 + ... + 2^126) = x^(2^127 - 1), and a final squaring
    yields x^(2^128 - 2).  For x == 0 the result is 0.
    """
    acc = 1 << 127  # multiplicative identity, same value as reverseBits(1, 128)
    power = x       # x^(2^k) on iteration k
    for _ in range(127):
        acc = gf128_mul(acc, power, R)
        power = gf2_pow_2(power, R)
    return gf2_pow_2(acc, R)

# Shorter equivalent of gf2_inv(), delegating to generic exponentiation:
def gf2_inv_v2(x, R):
    """Return x^(2^128 - 2), the inverse of a non-zero x, via gf2_pow."""
    exponent = (1 << 128) - 2
    return gf2_pow(x, exponent, R)

# https://math.stackexchange.com/q/943417
# https://crypto.stackexchange.com/q/17988
def gf2_sqrt(x, R):
    """Return the square root of x in GF(2^128), computed as x^(2^127).

    Since x^(2^128) == x for every element of GF(2^128), squaring x^(2^127)
    gives back x; the result is obtained by squaring x 127 times.
    """
    remaining = 127
    while remaining:
        x = gf2_pow_2(x, R)
        remaining -= 1
    return x

# Shorter equivalent of gf2_sqrt(), delegating to generic exponentiation:
def gf2_sqrt_v2(x, R):
    """Return x^(2^127), the square root of x in GF(2^128), via gf2_pow."""
    exponent = 1 << 127
    return gf2_pow(x, exponent, R)

def gf128_div(N, D, R):
    """Return N / D in GF(2^128), i.e. N multiplied by the inverse of D.

    Args:
        N: numerator (dividend).
        D: denominator (divisor); must be non-zero.
        R: reduction constant passed through to the field operations.

    Raises:
        ZeroDivisionError: if D is zero.  Zero has no inverse, and
            gf2_inv(0) evaluates to 0, so without this check the division
            would silently return 0.
    """
    if D == 0:
        raise ZeroDivisionError("division by zero in GF(2^128)")
    inverse = gf2_inv(D, R)
    # Self-check: the generic-exponentiation inverse must agree.
    # NOTE(review): this roughly doubles the cost of each division; consider
    # moving the cross-check into the test suite.
    assert inverse == gf2_inv_v2(D, R)
    return gf128_mul(N, inverse, R)

def test_sqrt(t, R):
    """Check that both square-root routines recover t from t squared."""
    square = gf2_pow_2(t, R)
    for sqrt_fn in (gf2_sqrt, gf2_sqrt_v2):
        assert sqrt_fn(square, R) == t

def test_sqrt2(t, R):
    """Check that squaring either computed square root of t gives back t."""
    for sqrt_fn in (gf2_sqrt, gf2_sqrt_v2):
        root = sqrt_fn(t, R)
        assert gf2_pow_2(root, R) == t

def test_mul():
    """Check gf128_mul with SageMath-computed inverses and the identity."""
    R = 0xE1000000000000000000000000000000
    one = reverseBits(1, 128)

    # (multiplier, its modular inverse as found by SageMath).  The inverses
    # are already in the bit-reflected form gf128_mul works with.  Multiplying
    # x by a multiplier and then by its inverse must restore x.
    x = 0xffffeeee0000
    inverse_pairs = (
        (0x12345678, 0x30c945e1b3d89978798b578e5a2e04bc),
        (2, 0xc2000000000000000000000000000001),
        (3, 0x41ffffffffffffffffffffffffffffff),
        (4, 0x46000000000000000000000000000003),
    )
    for multiplier, inverse in inverse_pairs:
        prod = gf128_mul(reverseBits(x, 128), reverseBits(multiplier, 128), R)
        restored = gf128_mul(prod, inverse, R)
        assert reverseBits(restored, 128) == x

    # 1*1 = 1
    assert reverseBits(gf128_mul(one, one, R), 128) == 1
    # 1*x = x
    assert gf128_mul(one, 0x12345678, R) == 0x12345678

def test_div():
    """Check gf128_div by undoing known products and by dividing by one."""
    R = 0xE1000000000000000000000000000000
    one = reverseBits(1, 128)

    # Dividing a product by one factor must give back the other factor.
    product = gf128_mul(0x12345, 0x7689, R)
    assert gf128_div(product, 0x12345, R) == 0x7689
    assert gf128_div(product, 0x7689, R) == 0x12345

    product = gf128_mul(0xffff, 5, R)
    assert gf128_div(product, 5, R) == 0xffff

    # Dividing by the identity leaves the numerator unchanged.
    for numerator in (one, 0xffff):
        assert gf128_div(numerator, one, R) == numerator

def test_sqrt_main():
    """Run both square-root round-trip checks over a set of sample elements."""
    R = 0xE1000000000000000000000000000000
    samples = (
        reverseBits(1, 128),
        1,
        0x12345,
        0x12345678abcdef000011223344556677,
        0xf2345678abcdef00001122334455667f,
    )
    # All test_sqrt cases run first, then all test_sqrt2 cases.
    for check in (test_sqrt, test_sqrt2):
        for sample in samples:
            check(sample, R)

def test_pow():
    """Cross-check gf2_pow against explicit squaring and multiplication."""
    R = 0xE1000000000000000000000000000000

    # one^(2^127) must still be one
    one = reverseBits(1, 128)
    assert reverseBits(gf2_pow(one, 2**127, R), 128) == 1

    # x^4 == (x^2)^2
    base = 0x12345678ee
    assert gf2_pow(base, 4, R) == gf2_pow_2(gf2_pow_2(base, R), R)

    # x^3 == x * x^2
    base = 0x12345678ff
    assert gf2_pow(base, 3, R) == gf128_mul(base, gf2_pow_2(base, R), R)

if __name__ == "__main__":
    # Run the self-tests only when executed as a script, so importing this
    # module for its field arithmetic does not trigger slow (and potentially
    # import-aborting) assertions.
    test_mul_vectors()
    test_mul()
    test_div()
    test_pow()
    test_sqrt_main()


