from copy import deepcopy

# This program replicates the behavior of the electro-mechanical Enigma machines
# used by the Germans in WWII. The machine here is a 3-rotor model with 
# a plug board. 

# Plug board: each string is a pair of letters that are swapped
s = ['ar', 'dq', 'es', 'kt', 'lp', 'mf', 'yz']

# Simple rotation: one 26-cycle mapping every letter to the next one
p = ['abcdefghijklmnopqrstuvwxyz']

# Rotor wheels, written as disjoint cycles: a letter maps to the letter after
# it in its string, and the last letter wraps around to the first.
# A notch position is the number of steps from position 'a' after which the
# rotor carries the next rotor along.
r1 = ['aeltphqxru', 'bknw', 'cmoy', 'dfg', 'iv', 'jz', 's']
r2 = ['a', 'bj', 'cdklhup', 'esz', 'fixvyomw', 'gr', 'nt', 'q']
r2_notch_position = 5
r3 = ['abdhpejt', 'cflvmzoyqirwukxsg', 'n']
r3_notch_position = 10
# Reflector (13 swapped pairs)
r = ['ay', 'br', 'cu', 'dh', 'eq', 'fs', 'gl', 'ip', 'jx', 'kn', 'mo', 'tz', 'vw']

# Choose which rotor goes in which slot: n is the first rotor after the plug
# board (stepped on every key press), m the middle one, l the last one before
# the reflector.
n = r2
m = r3
l = r1

n_notch_position = r2_notch_position
m_notch_position = r3_notch_position

# Initial rotor settings (start positions)
n_rotor_initial = 'a'
m_rotor_initial = 'a'
l_rotor_initial = 'a'

# Initial ring settings
n_ring_initial = 'A'
m_ring_initial = 'A'
l_ring_initial = 'A'

# Ring settings as 0-based offsets
n_p_count = ord(n_ring_initial) - ord('A')
m_p_count = ord(m_ring_initial) - ord('A')
l_p_count = ord(l_ring_initial) - ord('A')

# Rotor start positions as 0-based offsets
n_p_power = ord(n_rotor_initial) - ord('a')
m_p_power = ord(m_rotor_initial) - ord('a')
l_p_power = ord(l_rotor_initial) - ord('a')


def initial_notch_count(position, notch_position):
	"""Return the starting value of a rotor's turnover counter.

	The counter goes up by one each time the rotor steps, and the next rotor
	is carried when it reaches 26 (it then restarts at 0). A rotor that
	carries after ``notch_position`` steps from 'a' and starts at
	``position`` (0 = 'a') must carry after ``(notch_position - position) % 26``
	steps (a full 26 when that is 0), so the counter starts at
	``(position - notch_position) % 26``.

	The ring setting is deliberately not involved: stepping advances the
	start position, so that is what decides when the notch is reached.
	"""
	# % keeps the result in 0..25; a value of 26 would never trigger the carry.
	return (position - notch_position) % 26


# Take into account the notches on the rotors
n_count = initial_notch_count(n_p_power, n_notch_position)
m_count = initial_notch_count(m_p_power, m_notch_position)

# Save initial positions for reinitialization
n_count_hold, m_count_hold = n_count, m_count

def invert(perm):
	"""Return the inverse of a permutation written as a list of cycles.

	Reading a cycle backwards undoes it, so each cycle string is simply
	reversed; the order of the cycles is kept.
	"""
	inverted = []
	for cycle in perm:
		inverted.append(cycle[::-1])
	return inverted

# Invert permutations for the return path through the rotors. The plug board
# is left as-is: every cycle in s is a pair, so it is its own inverse.
s_inv = s
p_inv, n_inv, m_inv, l_inv = [invert(perm) for perm in (p, n, m, l)]

# Sanity check that a permutation covers the whole alphabet exactly once
def check_perm(p):
	"""Check that ``p`` (a list of cycle strings) permutes 'a'..'z'.

	Every lowercase letter must occur exactly once across all cycles. A plain
	length check would accept a repeated letter standing in for a missing one.

	Raises:
		AssertionError: if a letter is missing, repeated, or not in 'a'..'z'.
	"""
	letters = sorted(''.join(p))
	if letters != list('abcdefghijklmnopqrstuvwxyz'):
		# Explicit raise instead of `assert` so the check survives `python -O`.
		raise AssertionError(
			'not a permutation of the alphabet: {0!r}'.format(p))

# Run a character through the whole machine, end to end
def convert_letter(c_in):
	"""Encipher one character, stepping the rotors first.

	After the rotors are advanced, the character is passed through every
	component of the global ``machine`` list in order, each one mapping it
	with ``do_ring``.
	"""
	preprocess_rotation()
	signal = c_in
	for component in machine:
		signal = do_ring(component, signal)
	return signal

# Run a character through a ring (rotor, reflector, rotation) or plug board
def do_ring(r, c):
	"""Map character ``c`` through the permutation ``r``.

	``r`` is a list of cycle strings: a letter maps to the letter after it in
	its cycle, and the last letter wraps around to the first. A character
	found in no cycle is returned unchanged, which is how letters without a
	plug pass through the plug board.
	"""
	for cycle in r:
		i = cycle.find(c)
		if i != -1:
			# Modulo handles the wrap from the last letter back to the first.
			return cycle[(i + 1) % len(cycle)]
	# Not in any cycle: a fixed point of this component.
	return c

def preprocess_rotation():
	"""Step the rotors for one key press by rewriting the global ``machine``.

	The first rotor ``n`` steps every time. Its turnover counter ``n_count``
	carries into the middle rotor ``m`` whenever it wraps past 25, and
	``m_count`` carries into the last rotor ``l`` in the same way. A rotor
	(and its inverse on the return path) is stepped by surrounding it with
	``p`` before and ``p_inv`` after in the component list.

	NOTE: the component list grows by four entries per stepped rotor on every
	key press, so the cost per letter grows with the message length.
	"""
	global machine, n_count, m_count

	# Counters wrap modulo 26; a carry happens when one comes back to 0. The
	# modulo keeps a start value outside 0..25 from postponing the carry.
	n_count = (n_count + 1) % 26
	rotate_m = n_count == 0
	rotate_l = False
	if rotate_m:
		m_count = (m_count + 1) % 26
		rotate_l = m_count == 0

	# Components (compared by value) that move on this key press.
	stepping = [n, n_inv]
	if rotate_m:
		stepping += [m, m_inv]
	if rotate_l:
		stepping += [l, l_inv]

	result = []
	for a_comp in machine:
		if a_comp in stepping:
			# Conjugating by the one-letter rotation advances the rotor by one.
			result.extend((p, a_comp, p_inv))
		else:
			result.append(a_comp)

	# The component lists are never mutated, so the new list can share them;
	# no deep copy of the whole machine is needed.
	machine = result
	

def reinitialize():
	"""Return the machine to the state it had right after being built.

	Both the component list and the two turnover counters are restored, so
	feeding the ciphertext through afterwards starts from the same settings
	used for enciphering.
	"""
	global machine, n_count, m_count

	n_count = n_count_hold  # initial notch position, first rotor
	m_count = m_count_hold  # initial notch position, middle rotor
	machine = deepcopy(original_state)

# Start of program

def _rotor_stage(rotor, power, count):
	"""Return the components that place ``rotor`` at its initial setting.

	On the way in, the signal is rotated ``power`` times by ``p`` and
	``count`` times by ``p_inv``. It then passes through ``rotor`` and is
	rotated back by the same amounts (``power`` times ``p_inv``, ``count``
	times ``p``).
	"""
	stage = [p] * power + [p_inv] * count
	stage.append(rotor)
	stage += [p_inv] * power + [p] * count
	return stage


# Build machine. Layout without the setting rotations:
#   [s, n, m, l, r, l_inv, m_inv, n_inv, s_inv]
machine = [s]  # Start with plugboard
machine += _rotor_stage(n, n_p_power, n_p_count)
machine += _rotor_stage(m, m_p_power, m_p_count)
machine += _rotor_stage(l, l_p_power, l_p_count)
machine.append(r)
# Return path: same settings, inverse rotors, reverse order.
machine += _rotor_stage(l_inv, l_p_power, l_p_count)
machine += _rotor_stage(m_inv, m_p_power, m_p_count)
machine += _rotor_stage(n_inv, n_p_power, n_p_count)
machine.append(s_inv)  # End with plugboard inverted

# Save initial state for reinitialization
original_state = deepcopy(machine)

# End of program

if __name__ == "__main__":
	def _run(text):
		"""Feed text through the machine, echo each mapping, return the result."""
		out = []
		for ch in text:
			mapped = convert_letter(ch)
			print('{0:s} -> {1:s}'.format(ch, mapped))
			out.append(mapped)
		return ''.join(out)

	# Encipher, reset to the starting settings, then decipher the ciphertext.
	pt = 'zizouxisxaxcat'
	ct = _run(pt)

	print('\nReinitializing\n')
	reinitialize()

	_run(ct)

# Output:

# z -> g
# i -> a
# z -> e
# o -> y
# u -> c
# x -> s
# i -> b
# s -> i
# x -> y
# a -> i
# x -> y
# c -> j
# a -> k
# t -> f
#
# Reinitializing
# 
# g -> z
# a -> i
# e -> z
# y -> o
# c -> u
# s -> x
# b -> i
# i -> s
# y -> x
# i -> a
# y -> x
# j -> c
# k -> a
# f -> t