package de.brightbyte.data;

import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

/**
 * A multimap from int keys to int values, stored as two parallel arrays
 * ({@code ids} and {@code val}) whose first {@code size} slots are kept sorted by key,
 * so lookups can use binary search. A key may occur any number of times; all entries
 * for one key form a contiguous run. The order of values within a run is not
 * maintained (see the TODO in {@link #put(int, int[], int, int)}).
 * <p>
 * The graph-oriented methods ({@link #getSinks()}, {@link #getSources()},
 * {@link #pruneSinks()}) treat each entry as a directed edge {@code id -> value}.
 * <p>
 * This class performs no synchronization.
 */
public class IntRelation {

	/** Keys of all entries; the range [0, size) is sorted ascending. */
	protected int[] ids;
	/** Values of all entries; val[i] belongs to key ids[i]. */
	protected int[] val;
	/** Number of entries in use; slots at and beyond this index are spare capacity. */
	protected int size;
	
	/** Creates an empty relation with an initial capacity of 1024 entries. */
	public IntRelation() {
		this(1024);
	}
	
	/**
	 * Creates an empty relation.
	 *
	 * @param capacity number of entries that fit before the arrays must grow
	 */
	public IntRelation(int capacity) {
		ids = new int[capacity];
		val = new int[capacity];
	}
	
	/**
	 * Creates a copy of the given relation.
	 *
	 * @param rel the relation to copy
	 */
	public IntRelation(IntRelation rel) {
		this(rel.size());
		putAll(rel);
	}

	/**
	 * Creates a relation from parallel key and value lists; entry i maps k[i] to v[i].
	 *
	 * @param k keys
	 * @param v values, at least as many as there are keys
	 */
	public IntRelation(IntList k, IntList v) {
		this(k.size());
		putAll(k, v);
	}

	/**
	 * Creates a relation from the slices k[ofs..ofs+len) and v[ofs..ofs+len).
	 *
	 * @param k keys
	 * @param v values, paired with k by index
	 * @param ofs first index to read
	 * @param len number of entries to read
	 */
	public IntRelation(int[] k, int[] v, int ofs, int len) {
		this(len);
		putAll(k, v, ofs, len);
	}

	/**
	 * Grows the backing arrays so that they can hold at least {@code cap} entries.
	 * The new capacity is the larger of {@code cap} and 1.5 times the current capacity,
	 * limited to {@code Integer.MAX_VALUE}. Does nothing if the arrays are already large enough.
	 *
	 * @param cap the minimum required capacity
	 */
	public void ensureCapacity(int cap) {
		if (ids.length >= cap) return;

		// grow geometrically from the current capacity so repeated inserts are amortized O(1) reallocations
		long c = Math.max((long) ids.length * 3 / 2, cap);
		if (c > Integer.MAX_VALUE) c = Integer.MAX_VALUE;

		ids = Arrays.copyOf(ids, (int) c);
		val = Arrays.copyOf(val, (int) c);
	}
	
	/**
	 * Moves all entries from {@code pos} onwards by {@code ofs} slots and adjusts {@code size}.
	 * A positive {@code ofs} opens a gap of {@code ofs} slots at {@code pos}, growing the arrays
	 * as needed. The gap still holds the old values and must be filled in by the caller.
	 * A negative {@code ofs} deletes the {@code -ofs} entries starting at {@code pos}.
	 */
	protected void shift(int pos, int ofs) {
		if (size+ofs>=ids.length) ensureCapacity(size + ofs);

		if (pos>=size) {
			size += ofs;
			return;
		}
		
		int from = ofs < 0 ? pos - ofs : pos ;
		int to   = ofs < 0 ? pos : pos + ofs ;
		
		System.arraycopy(ids, from, ids, to, size - from);
		System.arraycopy(val, from, val, to, size - from);

		size += ofs;
	}
	
	/**
	 * Locates {@code id} among the current entries.
	 *
	 * @return the index of the first entry with this key, or {@code -(insertionPoint + 1)} if absent
	 */
	protected int find(int id) {
		if (size==0) {
			return -1;
		}
		else if (id>ids[size-1]) {
			// larger than every key: the insertion point is the end
			return -size -1; 
		}
		
		return find(id, ids, size); 
	}
	
	/**
	 * Returns the index just past the run of entries that share the key found at index {@code i}.
	 */
	private int runEnd(int i) {
		int id = ids[i];
		int j = i + 1;
		while (j < size && ids[j] == id) j++;
		return j;
	}
	
	/** Sorts all entries by key, keeping each value paired with its key. */
	protected void sort() {
		sort(ids, val, 0, size);
	}

	/** Sorts the entries in [off, off+len) by key, keeping each value paired with its key. */
	protected void sort(int off, int len) {
		sort(ids, val, off, len);
	}

	/**
	 * Adapted from Arrays.sort: sorts x[off..off+len) ascending and applies the same
	 * permutation to y, so that entries in y stay paired with their entries in x.
	 * The sort is not stable.
	 */
	protected static void sort(int[] x, int[] y, int off, int len) {
		// Insertion sort on smallest arrays
		if (len < 7) {
			for (int i=off; i<len+off; i++)
				for (int j=i; j>off && x[j-1]>x[j]; j--)
					swap(x, y, j, j-1);
			return;
		}

		// Choose a partition element, v
		int m = off + (len >> 1);       // Small arrays, middle element
		if (len > 7) {
			int l = off;
			int n = off + len - 1;
			if (len > 40) {        // Big arrays, pseudomedian of 9
				int s = len/8;
				l = med3(x, l,     l+s, l+2*s);
				m = med3(x, m-s,   m,   m+s);
				n = med3(x, n-2*s, n-s, n);
			}
			m = med3(x, l, m, n); // Mid-size, med of 3
		}
		int v = x[m];

		// Establish Invariant: v* (<v)* (>v)* v*
		int a = off, b = a, c = off + len - 1, d = c;
		while(true) {
			while (b <= c && x[b] <= v) {
				if (x[b] == v)
					swap(x, y, a++, b);
				b++;
			}
			while (c >= b && x[c] >= v) {
				if (x[c] == v)
					swap(x, y, c, d--);
				c--;
			}
			if (b > c)
				break;
			swap(x, y, b++, c--);
		}

		// Swap partition elements back to middle
		int s, n = off + len;
		s = Math.min(a-off, b-a  );  vecswap(x, y, off, b-s, s);
		s = Math.min(d-c,   n-d-1);  vecswap(x, y, b,   n-s, s);

		// Recursively sort non-partition-elements
		if ((s = b-a) > 1)
			sort(x, y, off, s);
		if ((s = d-c) > 1)
			sort(x, y, n-s, s);
	}
	
	/**
	 * Swaps x[a] with x[b], and y[a] with y[b].
	 */
	private static void swap(int x[], int y[], int a, int b) {
		int t = x[a];
		x[a] = x[b];
		x[b] = t;

		t = y[a];
		y[a] = y[b];
		y[b] = t;
	}

	/**
	 * Swaps x[a .. (a+n-1)] with x[b .. (b+n-1)], and the same ranges in y.
	 */
	private static void vecswap(int x[], int y[], int a, int b, int n) {
		for (int i=0; i<n; i++, a++, b++)
			swap(x, y, a, b);
	}

	/**
	 * Returns the index of the median of the three indexed integers.
	 */
	private static int med3(int x[], int a, int b, int c) {
		return (x[a] < x[b] ?
				(x[b] < x[c] ? b : x[a] < x[c] ? c : a) :
				(x[b] > x[c] ? b : x[a] > x[c] ? c : a));
	}
	
	/**
	 * Binary search for {@code id} in the sorted range ids[0..size).
	 *
	 * @return the index of the first occurrence of {@code id}, or {@code -(insertionPoint + 1)}
	 *         if it does not occur
	 */
	protected static int find(int id, int[] ids, int size) {
		if (size==0) return -1;
		
		int l = 0;
		int r = size-1;
		int i;
		
		// shortcut when id equals the last key
		if ( id==ids[r] ) {
			i = l = r; 
		}
		else i = (l + r) >>> 1; // unsigned shift: no overflow for large l + r
		
		while (l <= r) {
			int x = ids[i];

			if (x < id) {
				l = i + 1;
			}
			else if (x > id) {
				r = i - 1;
			}
			else {
				// step back to the first entry of the run of equal keys
				while (i>0 && ids[i-1]==id) i--;
				return i; 
			}

			i = (l + r) >>> 1;
		}
		
		return -(l + 1);  
	}
	
	private static final int[] none = new int[0];
	
	/**
	 * Returns all values stored for {@code id}.
	 *
	 * @return a fresh array of the values, or an empty array if the key is absent
	 */
	public int[] get(int id) {
		int i = find(id);
		if (i<0) return none;
		
		return Arrays.copyOfRange(val, i, runEnd(i));
	}
	
	/**
	 * Returns the number of entries with key {@code id}.
	 */
	public int count(int id) {
		int i = find(id);
		return i < 0 ? 0 : runEnd(i) - i;
	}
	
	/**
	 * Returns the number of entries with key {@code id} whose value is not {@code ignoring}.
	 */
	public int countIgnoring(int id, int ignoring) {
		int i = find(id);
		if (i<0) return 0;
		
		int end = runEnd(i);
		int c = 0;
		for (int j = i; j < end; j++) {
			if (val[j]!=ignoring) c++;
		}
		
		return c;
	}
	
	/**
	 * Returns true if at least one entry has key {@code id}.
	 */
	public boolean containsKey(int id) {
		return find(id) >= 0;
	}
	
	/**
	 * Returns true if at least one entry has key {@code id} and a value other than {@code ignoring}.
	 */
	public boolean containsKeyIgnoring(int id, int ignoring) {
		int i = find(id);
		if (i<0) return false;

		int end = runEnd(i);
		for (int j = i; j < end; j++) {
			if (val[j]!=ignoring) return true;
		}
		
		return false;
	}
	
	/**
	 * Adds all entries of {@code rel} to this relation.
	 */
	public void putAll(IntRelation rel) {
		// rel is already sorted, so re-sorting is only needed when merging into existing entries
		putAll(rel.ids, rel.val, 0, rel.size, size > 0);
	}
	
	/**
	 * Adds an entry k[i] -> v[i] for every index of {@code k}.
	 *
	 * @throws IllegalArgumentException if {@code v} has fewer elements than {@code k}
	 */
	public void putAll(IntList k, IntList v) {
		// the buffers may be longer than the lists, so the range check in putAll cannot catch this
		if (v.size < k.size) throw new IllegalArgumentException("value list shorter than key list: "+v.size+" < "+k.size);
		putAll(k.data, v.data, 0, k.size);
	}
	
	/**
	 * Adds the entries k[i] -> v[i] for i in [ofs, ofs+len).
	 *
	 * @throws IllegalArgumentException if either array is shorter than {@code ofs+len}
	 */
	public void putAll(int[] k, int[] v, int ofs, int len) {
		putAll(k, v, ofs, len, true);
	}
	
	/**
	 * Appends the entries k[i] -> v[i] for i in [ofs, ofs+len), then optionally re-sorts.
	 * Passing {@code sort == false} is only safe if the result is already sorted by key.
	 */
	protected void putAll(int[] k, int[] v, int ofs, int len, boolean sort) {
		if (k.length < ofs+len) throw new IllegalArgumentException("out of range: "+k.length+" < "+(ofs+len));
		if (v.length < ofs+len) throw new IllegalArgumentException("out of range: "+v.length+" < "+(ofs+len));

		ensureCapacity(size+len);
		
		System.arraycopy(k, ofs, ids, size, len);
		System.arraycopy(v, ofs, val, size, len);
		
		size+= len;
		
		if (sort) sort();
	}
	
	/**
	 * Adds one entry id -> x for each element x of {@code v}.
	 */
	public void put(int id, int[] v) {
		put(id, v, 0, v.length);
	}
	
	/**
	 * Adds one entry id -> v[j] for each j in [ofs, ofs+len). The new entries are
	 * placed in front of any existing entries for {@code id}.
	 *
	 * @throws IllegalArgumentException if the range is not within {@code v}
	 */
	public void put(int id, int[] v, int ofs, int len) {
		//TODO: keep val sorted too!
		
		// validate before shifting, so a bad range cannot leave an unfilled gap behind
		if (ofs < 0 || len < 0 || v.length < ofs+len) throw new IllegalArgumentException("out of range: "+v.length+" < "+(ofs+len));

		int i = find(id);
		if (i<0) i = -i -1; // decode the insertion point
		
		shift(i, len); // grows the arrays if necessary
		
		for (int j= 0; j<len; j++) {
			ids[i+j] = id; 
			val[i+j] = v[ofs+j];
		}
	}

	/**
	 * Adds the entry id -> v, placed in front of any existing entries for {@code id}.
	 */
	public void put(int id, int v) {
		int i = find(id);
		if (i<0) i = -i -1; // decode the insertion point
		
		shift(i, 1);
		
		ids[i] = id; 
		val[i] = v; 
	}
	
	/**
	 * Removes all entries with key {@code id}.
	 *
	 * @return the number of entries removed
	 */
	public int remove(int id) {
		int i = find(id);
		if (i<0) return 0;

		int sz = runEnd(i) - i;
		shift(i, -sz);
		
		return sz;
	}
	
	/**
	 * Removes every entry id -> v.
	 *
	 * @return the number of entries removed
	 */
	public int remove(int id, int v) {
		int i = find(id);
		if (i<0) return 0;
		
		int c = 0;

		do {
			if (val[i]==v) {
				shift(i, -1); // the next entry moves into slot i
				c++;
			}
			else i++;
		} while (i<size && ids[i]==id);
		
		return c;
	}
	
	/**
	 * Sets the value of every entry with key {@code id} to {@code n}.
	 *
	 * @return the number of entries changed
	 */
	public int fill(int id, int n) {
		int i = find(id);
		if (i<0) return 0;
		
		int end = runEnd(i);
		Arrays.fill(val, i, end, n);
		
		return end - i;
	}
	
	/**
	 * Changes every entry id -> v into id -> n.
	 *
	 * @return the number of entries changed
	 */
	public int replace(int id, int v, int n) {
		int i = find(id);
		if (i<0) return 0;
		
		int end = runEnd(i);
		int c = 0;
		for (int j = i; j < end; j++) {
			if (val[j]==v) {
				val[j] = n;
				c++;
			}
		}
		
		return c;
	}
	
	/**
	 * Repeatedly removes edges whose target is not a key, until no such edge remains.
	 *
	 * @return the total number of entries removed
	 */
	public int pruneSinks() { //O(n^2*log(n)) worst case, O(n*log^2(n)) for "good" graphs
		int n = 0;
		int c;
		do {
			c = stripSinks();
			n+= c;
		} while (c>0);
		
		return n;
	}
	
	/**
	 * Removes, in one pass, every entry whose value is not a key at the start of the pass.
	 *
	 * @return the number of entries removed
	 */
	protected int stripSinks() {  //O(n*log(n))
		// Decide first and compact afterwards. Compacting overwrites ids[] in place,
		// which would break the sorted order that containsKey()'s binary search relies on.
		boolean[] sink = new boolean[size];
		for (int i = 0; i < size; i++) {
			sink[i] = !containsKey(val[i]);
		}

		int j = 0;
		for (int i = 0; i < size; i++) {
			if (sink[i]) continue;
			
			ids[j] = ids[i]; 
			val[j] = val[i];
			j++;
		}

		int removed = size - j;
		size = j;
		return removed;
	}
	
	/** Returns the number of entries. */
	public int size() {
		return size;
	}

	/**
	 * Returns a view of the distinct keys in ascending order. Each call to
	 * {@code iterator()} starts a new iteration. The iterator's {@code remove()} removes
	 * all entries for the key last returned by {@code next()}. Changes made to the relation
	 * by other means during an iteration are not detected.
	 */
	public Iterable<Integer> keys() {
		return new Iterable<Integer>() {
			public Iterator<Integer> iterator() {
				return new Iterator<Integer>() {
					// start of the run returned by the last next(), or -1 if there is none or it was removed
					int prev=-1;
					// start of the run the next call to next() returns
					int next=0;
					
					public boolean hasNext() {
						return next < size;
					}
				
					public Integer next() {
						if (next >= size) throw new NoSuchElementException();
						prev = next;
						next = runEnd(prev);
						return ids[prev];
					}
				
					public void remove() {
						if (prev < 0) throw new IllegalStateException("next() has not been called, or remove() was already called for this element");
						// the removed run lay directly before next, so next moves back by its length
						next -= IntRelation.this.remove(ids[prev]);
						prev = -1;
					}
				};
			}
		};
	}

	/**
	 * Returns the distinct keys in ascending order.
	 */
	public IntList getKeys() {
		IntList list = new IntList();
		
		for (int i = 0; i < size; i = runEnd(i)) {
			list.add(ids[i]);
		}
		
		return list;
	}
	
	/**
	 * Returns all values in entry order (grouped by ascending key), including duplicates.
	 */
	public IntList getValues() {
		IntList list = new IntList(size);
		
		for (int i = 0; i < size; i++) {
			list.add(val[i]);
		}
		
		return list;
	}	
	
	/**
	 * Returns every value that is not also a key (the targets of edges that lead nowhere).
	 * Each such value is listed once, in order of first appearance.
	 */
	public IntList getSinks() { //O(n*log(n))
		int[] found = new int[size];
		int n = 0;
		for (int i=0; i<size; i++) {
			int v = val[i];
			if (!containsKey(v)) found[n++] = v;
		}
		
		// de-duplicate while keeping first-occurrence order: find() returns the first
		// index of a value in the sorted copy, so equal values share one "seen" slot
		int[] sorted = Arrays.copyOf(found, n);
		Arrays.sort(sorted);
		boolean[] seen = new boolean[n];
		
		IntList ends = new IntList();
		for (int i = 0; i < n; i++) {
			int k = find(found[i], sorted, n);
			if (!seen[k]) {
				seen[k] = true;
				ends.add(found[i]);
			}
		}
		
		return ends;
	}
	
	/**
	 * Returns every key that never appears as a value (nodes without incoming edges),
	 * each once, in ascending order.
	 */
	public IntList getSources() { //O(n*log(n))
		IntList ends = new IntList();
		
		// sort a private copy of exactly the used range; spare capacity must not be searched
		int[] targets = Arrays.copyOf(val, size);
		Arrays.sort(targets);
		
		for (int i=0; i<size; i = runEnd(i)) {
			int id = ids[i];
			if (find(id, targets, size)<0) ends.add(id);
		}
		
		return ends;
	}
	
	@Override
	public String toString() {
		StringBuilder s = new StringBuilder();
		s.append('[');
		for (int i=0; i<size; i++) {
			if (i>0) s.append(',').append(' ');
			s.append(ids[i]).append('-').append('>').append(val[i]);
		}
		
		s.append(']');
		
		return s.toString();
	}
	
	/**
	 * Removes every entry whose value is {@code v}.
	 *
	 * @return the number of entries removed
	 */
	public int removeValue(int v) {
		int j = 0;

		for (int i = 0; i < size; i++) {
			if (val[i]==v) continue;
			
			ids[j] = ids[i]; 
			val[j] = val[i];
			j++;
		}
		
		int removed = size - j;
		size = j;
		
		return removed;
	}
	
	/**
	 * Removes all entries whose key occurs in {@code del}. The argument array is not modified.
	 *
	 * @return the number of entries removed
	 */
	public int remove(int[] del) {
		if (del.length == 0) return 0;
		
		int[] keys = del.clone(); // do not reorder the caller's array
		Arrays.sort(keys);
		
		int j = 0;
		int n = 0;
		
		// merge-style walk: ids and keys are both sorted ascending
		for (int i = 0; i < size; i++) {
			while (n<keys.length && keys[n]<ids[i]) {
				n++;
			}

			//TODO: use bin search to bump i to the right spot!
			if (n<keys.length && ids[i]==keys[n]) continue;
			
			ids[j] = ids[i]; 
			val[j] = val[i];
			j++;
		}
		
		int c = size - j;
		size = j;
		
		return c;
	}
}
