package de.brightbyte.data;

import java.io.UnsupportedEncodingException;
import java.util.AbstractCollection;
import java.util.AbstractSet;
import java.util.Arrays;
import java.util.Collection;
import java.util.ConcurrentModificationException;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;

/**
 * A {@link Map} keyed by byte sequences, stored as a trie with one branch per unsigned
 * byte value (0-255) plus a dedicated slot ({@link #EOF}) marking the end of a key.
 *
 * <p>Keys are compared by content, not by array identity. The {@code Object}-typed lookup
 * methods ({@link #get(Object)}, {@link #remove(Object)}, {@link #containsKey(Object)}) also
 * accept a {@link String}, which is encoded as UTF-8 first. Null values are allowed; null
 * keys are not (they fail with {@link NullPointerException}).
 *
 * <p>Iteration visits each node's children in ascending unsigned byte order and reports a
 * key only after all longer keys that extend it. Iterators are fail-fast: adding or removing
 * a key other than through the iterator's own {@code remove()} makes its next {@code next()}
 * or {@code remove()} throw {@link ConcurrentModificationException}. Replacing the value of
 * an existing key is not a structural change. This class performs no synchronization.
 *
 * @param <V> the value type
 */
public class Trie<V> implements Map<byte[], V> {

	/**
	 * Depth-first iterator over the keys of this trie. Each call to {@link #next()} returns
	 * a freshly allocated array that the caller may keep or modify.
	 */
	public class KeyIterator implements Iterator<byte[]> {
		/** Bytes on the path from the root to {@link #current}; grown on demand. */
		protected byte[] buffer = new byte[256];
		/** Node whose child slots are being scanned; null once iteration is finished. */
		protected Node current = null;
		/** Next child slot of {@link #current} to examine; EOF + 1 means the node is done. */
		protected int pos = -1;
		/** Depth of {@link #current}, i.e. the number of valid bytes in {@link #buffer}. */
		protected int level = 0;
		/** Value of the trie's modification counter this iterator is in sync with. */
		protected int expectedMod;
		/** End node of the key the iterator is positioned at (the one next() will return). */
		protected EndNode<V> pendingEnd = null;
		/** Private copy of the key last returned by next(); null before next() or after remove(). */
		protected byte[] lastKey = null;
		/** End node of the key last returned by next(). */
		protected EndNode<V> lastEnd = null;

		public KeyIterator () {
			current = root;
			pos = 0;
			expectedMod = mod;

			findNext();
		}

		/**
		 * Advances to the next node that carries an end marker, or sets {@link #current} to
		 * null when no key remains. A node's children are all descended before the node's
		 * own end marker is reported. Walking back up uses the parent links plus the byte
		 * recorded in {@link #buffer} for each level.
		 */
		protected void findNext() {
			while (current != null) {
				Node child = null;
				while (pos < EOF) {
					child = current.getChild(pos);
					if (child != null) break;
					pos++;
				}

				if (child != null) { // descend into the next non-empty branch
					if (level >= buffer.length) { // keys may be longer than the initial buffer
						byte[] grown = new byte[Math.max(16, buffer.length * 2)];
						System.arraycopy(buffer, 0, grown, 0, level);
						buffer = grown;
					}
					buffer[level++] = (byte) pos;
					current = child;
					pos = 0;
					continue;
				}

				if (pos == EOF) { // branches exhausted; report this node's own key, if any
					pos++;
					EndNode<V> end = endOf(current);
					if (end != null) {
						pendingEnd = end;
						return;
					}
				}

				if (level == 0) { // back at the root with nothing left: the end
					current = null;
					pendingEnd = null;
					return;
				}

				// pop: resume scanning the parent just after the branch we came from
				level--;
				current = current.getParent();
				pos = (buffer[level] & 0xFF) + 1;
			}
		}

		public boolean hasNext() {
			return current != null;
		}

		/**
		 * Returns the next key.
		 *
		 * @throws NoSuchElementException if the iteration has no more keys
		 * @throws ConcurrentModificationException if keys were added or removed other than
		 *         through this iterator
		 */
		public byte[] next() {
			checkForComodification();
			if (current == null) throw new NoSuchElementException();

			byte[] bytes = new byte[level];
			System.arraycopy(buffer, 0, bytes, 0, level);

			lastKey = bytes.clone(); // own copy: the caller may modify the returned array
			lastEnd = pendingEnd;

			findNext();
			return bytes;
		}

		/**
		 * Removes the key last returned by {@link #next()} from the trie. This is safe
		 * because pruning only detaches nodes left without children. Every node on the
		 * path to the iterator's pending key still has a child or an end marker.
		 *
		 * @throws IllegalStateException if next() has not been called, or remove() was
		 *         already called after the last next()
		 */
		public void remove() {
			if (lastKey == null) throw new IllegalStateException();
			checkForComodification();

			Trie.this.remove(lastKey);
			lastKey = null;
			lastEnd = null;
			expectedMod = mod;
		}

		/** Fails fast if the trie was structurally modified behind this iterator's back. */
		protected void checkForComodification() {
			if (mod != expectedMod) throw new ConcurrentModificationException();
		}
	}

	/**
	 * Entry produced by {@link Trie#entrySet()}. It holds the value seen when the entry was
	 * created; {@link #setValue} writes through to the trie via {@link Trie#put}. Equality
	 * and hash code compare the key by content (Arrays.equals / Arrays.hashCode) because
	 * identity comparison of freshly built key arrays would never match.
	 */
	public class MyEntry implements Map.Entry<byte[], V> {
		protected byte[] key;
		protected V value;

		public MyEntry(byte[] key, V value) {
			this.key = key;
			this.value = value;
		}

		public byte[] getKey() {
			return key;
		}

		public V getValue() {
			return value;
		}

		/** Stores the value in the trie and in this entry; returns the trie's previous value. */
		public V setValue(V value) {
			V old = put(key, value);
			this.value = value;
			return old;
		}

		@Override
		public boolean equals(Object o) {
			if (o == this) return true;
			if (!(o instanceof Map.Entry)) return false;

			Map.Entry<?, ?> other = (Map.Entry<?, ?>) o;
			Object otherKey = other.getKey();
			return otherKey instanceof byte[]
					&& Arrays.equals(key, (byte[]) otherKey)
					&& eq(value, other.getValue());
		}

		@Override
		public int hashCode() {
			return Arrays.hashCode(key) ^ (value == null ? 0 : value.hashCode());
		}

		@Override
		public String toString() {
			return Arrays.toString(key) + "=" + value;
		}
	}

	/** Child slot holding the end marker; slots 0-255 are the unsigned byte values. */
	protected static final int EOF = 256;

	/**
	 * Trie node. Slots 0-255 hold the child for each unsigned byte value; slot EOF holds the
	 * EndNode that marks the path to this node as a key. The slot array is allocated on the
	 * first insertion, so childless nodes (including every EndNode) carry no array.
	 */
	protected static class Node {
		protected Node[] children = null;
		protected int count;
		protected Node parent;

		public Node(Node parent) {
			this.parent = parent;
		}

		public Node getParent() {
			return parent;
		}

		public Node getChild(int ch) {
			return children == null ? null : children[ch];
		}

		public void putChild(int ch, Node n) {
			if (n == null) throw new NullPointerException();
			if (children == null) children = new Node[EOF + 1];
			if (children[ch] == null) count++;
			children[ch] = n;
		}

		/** Returns the child in slot {@code ch}, creating an empty one if absent. */
		public Node aquireChild(int ch) {
			Node n = getChild(ch);

			if (n == null) {
				n = new Node(this);
				this.putChild(ch, n);
			}

			return n;
		}

		/** Clears slot {@code ch} and returns its previous occupant (or null). */
		public Node removeChild(int ch) {
			if (children == null) return null;

			Node n = children[ch];
			if (n != null) {
				children[ch] = null;
				count--;
			}
			return n;
		}

		public int getChildCount() {
			return count;
		}
	}

	/** End marker stored in a node's EOF slot; carries the value mapped to that key. */
	protected static class EndNode<V> extends Node {
		protected V value;
		public EndNode(Node parent, V v) {
			super(parent);
			this.value = v;
		}
	}

	protected Node root;
	/** Number of keys currently stored. */
	protected int size = 0;
	/** Structural modification counter (keys added/removed, clear); used by iterators. */
	protected int mod = 0;

	public Trie() {
		root = new Node(null);
	}

	/**
	 * Converts a lookup key to bytes: a String is encoded as UTF-8, anything else is cast to
	 * byte[] (other types fail with ClassCastException; null passes through).
	 */
	private static byte[] toKey(Object key) {
		if (key instanceof String) {
			try {
				return ((String) key).getBytes("UTF-8");
			} catch (UnsupportedEncodingException e) {
				throw new Error("UTF-8 not supported", e);
			}
		}
		return (byte[]) key;
	}

	/** True if {@code o} is a type the lookup methods accept as a key. */
	private static boolean isKeyType(Object o) {
		return o instanceof byte[] || o instanceof String;
	}

	/** Null-safe equality. */
	private static boolean eq(Object a, Object b) {
		return a == null ? b == null : a.equals(b);
	}

	/** Returns the end marker stored at {@code n}, or null if {@code n} ends no key. */
	@SuppressWarnings("unchecked")
	private EndNode<V> endOf(Node n) {
		return (EndNode<V>) n.getChild(EOF);
	}

	/** Follows {@code key} from the root; returns the node reached or null if the path is missing. */
	private Node findNode(byte[] key) {
		Node n = root;
		for (int i = 0; i < key.length && n != null; i++) {
			n = n.getChild(key[i] & 0xFF); // bytes are signed; slots are 0-255
		}
		return n;
	}

	/** Returns the end marker for {@code key}, or null if the key is not present. */
	private EndNode<V> findEnd(byte[] key) {
		Node n = findNode(key);
		return n == null ? null : endOf(n);
	}

	/**
	 * Associates {@code value} with {@code key}.
	 *
	 * @return the previous value, or null if the key was absent (or mapped to null)
	 * @throws NullPointerException if {@code key} is null
	 */
	public V put(byte[] key, V value) {
		Node n = root;
		for (int i = 0; i < key.length; i++) {
			n = n.aquireChild(key[i] & 0xFF);
		}

		EndNode<V> en = endOf(n);
		if (en == null) {
			n.putChild(EOF, new EndNode<V>(n, value));
			size++;
			mod++; // only new keys are structural changes
			return null;
		}

		V old = en.value;
		en.value = value;
		return old;
	}

	/** Returns the value mapped to {@code key}, or null if absent (or mapped to null). */
	public V get(byte[] key) {
		EndNode<V> en = findEnd(key);
		return en == null ? null : en.value;
	}

	/**
	 * Removes {@code key} and prunes the nodes that no longer lead to any key.
	 *
	 * @return the removed value, or null if the key was absent (or mapped to null)
	 */
	public V remove(byte[] key) {
		Node n = findNode(key);
		EndNode<V> en = n == null ? null : endOf(n);
		if (en == null) return null;

		n.removeChild(EOF);

		// Walk back up the key's path, detaching every node left without children.
		for (int i = key.length; i > 0 && n.getChildCount() == 0; i--) {
			Node p = n.getParent();
			if (p == null) break;

			p.removeChild(key[i - 1] & 0xFF);
			n = p;
		}

		size--;
		mod++;
		return en.value;
	}

	/** Removes all keys. */
	public void clear() {
		root = new Node(null);
		size = 0;
		mod++;
	}

	/** True if {@code key} is present, even when it is mapped to null. Accepts byte[] or String. */
	public boolean containsKey(Object key) {
		return findEnd(toKey(key)) != null;
	}

	/** True if some key maps to {@code value}; scans all values. */
	public boolean containsValue(Object value) {
		return values().contains(value);
	}

	protected Set<Map.Entry<byte[], V>> entrySet = null;
	/** Live view of the mappings; removal through the view or its iterator is supported. */
	public Set<Map.Entry<byte[], V>> entrySet() {
		if (entrySet == null) entrySet = new AbstractSet<Map.Entry<byte[], V>>() {

			@Override
			public boolean remove(Object o) {
				if (!contains(o)) return false;
				Trie.this.remove(((Map.Entry<?, ?>) o).getKey());
				return true;
			}

			@Override
			public boolean contains(Object o) {
				if (!(o instanceof Map.Entry)) return false;

				Map.Entry<?, ?> e = (Map.Entry<?, ?>) o;
				Object k = e.getKey();
				if (!isKeyType(k)) return false;

				EndNode<V> en = findEnd(toKey(k));
				return en != null && eq(en.value, e.getValue());
			}

			@Override
			public void clear() {
				Trie.this.clear();
			}

			@Override
			public Iterator<Map.Entry<byte[], V>> iterator() {
				final KeyIterator it = new KeyIterator();
				return new Iterator<Map.Entry<byte[], V>>() {

					public boolean hasNext() {
						return it.hasNext();
					}

					public Map.Entry<byte[], V> next() {
						byte[] key = it.next();
						return new MyEntry(key, it.lastEnd.value);
					}

					public void remove() {
						it.remove();
					}
				};
			}

			@Override
			public int size() {
				return Trie.this.size();
			}

		};

		return entrySet;
	}

	protected Collection<V> values = null;
	/** Live view of the values; removal through its iterator is supported. */
	public Collection<V> values() {
		if (values == null) values = new AbstractCollection<V>() {
			@Override
			public void clear() {
				Trie.this.clear();
			}

			@Override
			public int size() {
				return Trie.this.size();
			}

			@Override
			public Iterator<V> iterator() {
				final KeyIterator it = new KeyIterator();
				return new Iterator<V>() {

					public boolean hasNext() {
						return it.hasNext();
					}

					public V next() {
						it.next();
						return it.lastEnd.value;
					}

					public void remove() {
						it.remove();
					}
				};
			}
		};

		return values;
	}

	protected Set<byte[]> keySet = null;
	/**
	 * Live view of the keys. {@code contains} and {@code remove} compare keys by content and
	 * also accept Strings (UTF-8 encoded), matching {@link #containsKey(Object)}.
	 */
	public Set<byte[]> keySet() {
		if (keySet == null) keySet = new AbstractSet<byte[]>() {

			@Override
			public Iterator<byte[]> iterator() {
				return new KeyIterator();
			}

			@Override
			public int size() {
				return Trie.this.size();
			}

			@Override
			public boolean contains(Object o) {
				return isKeyType(o) && Trie.this.containsKey(o);
			}

			@Override
			public boolean remove(Object o) {
				if (!contains(o)) return false;
				Trie.this.remove(o);
				return true;
			}

			@Override
			public void clear() {
				Trie.this.clear();
			}

		};

		return keySet;
	}

	/** Looks up a byte[] or String (UTF-8 encoded) key. */
	public V get(Object key) {
		return get(toKey(key));
	}

	public boolean isEmpty() {
		return size() == 0;
	}

	public void putAll(Map<? extends byte[], ? extends V> m) {
		for (Map.Entry<? extends byte[], ? extends V> e: m.entrySet()) {
			put(e.getKey(), e.getValue());
		}
	}

	/** Removes a byte[] or String (UTF-8 encoded) key. */
	public V remove(Object key) {
		return remove(toKey(key));
	}

	public int size() {
		return size;
	}

	/** Placeholder entry point; only constructs an empty trie. */
	public static void main(String[] args) {
		Trie<Object> t = new Trie<Object>();
	}
}
