package de.brightbyte.data;

import java.util.HashSet;
import java.util.Set;

/**
 * An {@link IntRelation} that additionally maintains a reverse index, so that the keys
 * pointing at a given value can be looked up and removed as well as the values of a key.
 *
 * <p>The forward direction is this object itself (accessed through {@code super}); the
 * reverse direction is the separate relation {@link #reverse}, which holds the pair
 * {@code (v, k)} for every forward pair {@code (k, v)}. Every mutator overridden here
 * updates both directions.
 */
public class IntBiRelation extends IntRelation {
	/** Reverse index: contains {@code (v, k)} for every forward pair {@code (k, v)}. */
	protected IntRelation reverse;

	/**
	 * When removing more than this many pairs at once, use a single ripple shift
	 * ({@code removeValue}) instead of repeated array copies (one {@code remove(k, v)} per pair).
	 * TODO: get a good initial value... may depend on size too?
	 */
	protected int shiftThreshold = 4;

	/**
	 * Placeholder written over values that {@link #doPruneSinks(IntList)} has scheduled for
	 * removal. NOTE(review): a genuine value equal to this marker would be treated as masked
	 * and removed by {@link #flushMasked()}.
	 */
	private final int maskMarker = Integer.MIN_VALUE;

	/** Values masked since the last {@link #flushMasked()}; their reverse entries still exist. */
	private final IntList masked = new IntList();

	/** Creates an empty bi-relation with an initial capacity of 1024. */
	public IntBiRelation() {
		this(1024);
	}

	/**
	 * Creates an empty bi-relation.
	 *
	 * @param capacity initial capacity of both the forward relation and the reverse index
	 */
	public IntBiRelation(int capacity) {
		super(capacity);
		reverse = new IntRelation(capacity);
	}

	/**
	 * Creates a bi-relation holding all pairs of {@code rel}.
	 * NOTE(review): the reverse index is only filled if the inherited putAll(IntRelation)
	 * dispatches to the overridden putAll(int[], int[], int, int, boolean) — confirm.
	 */
	public IntBiRelation(IntRelation rel) {
		this(rel.size());
		putAll(rel);
	}

	/**
	 * Creates a bi-relation from parallel key and value lists.
	 * NOTE(review): relies on the inherited putAll(IntList, IntList) reaching the overridden
	 * 5-argument putAll so the reverse index is filled — confirm.
	 */
	public IntBiRelation(IntList k, IntList v) {
		this(k.size());
		putAll(k, v);
	}

	/**
	 * Creates a bi-relation from {@code len} parallel entries of {@code k} and {@code v},
	 * starting at index {@code ofs}.
	 */
	public IntBiRelation(int[] k, int[] v, int ofs, int len) {
		this(len);
		putAll(k, v, ofs, len);
	}

	/** Tests whether {@code id} occurs as a key (forward direction). */
	@Override
	public boolean containsKey(int id) {
		return super.containsKey(id);
	}

	/** Tests whether {@code v} occurs as a value, using the reverse index. */
	public boolean containsValue(int v) {
		return reverse.containsKey(v);
	}

	/** Ensures capacity in both the forward relation and the reverse index. */
	@Override
	public void ensureCapacity(int cap) {
		super.ensureCapacity(cap);
		reverse.ensureCapacity(cap);
	}

	/** Delegates to {@link IntRelation#equals(Object)}; only the forward relation is compared. */
	@Override
	public boolean equals(Object obj) {
		return super.equals(obj);
	}

	/** Returns the values associated with key {@code id}. */
	@Override
	public int[] get(int id) {
		return super.get(id);
	}

	/** Returns the keys associated with value {@code v}, read from the reverse index. */
	public int[] getKeys(int v) {
		return reverse.get(v);
	}

	/** Returns the keys of the forward relation. */
	@Override
	public IntList getKeys() {
		return super.getKeys();
	}

	/** Returns the distinct values, i.e. the keys of the reverse index. */
	public IntList getValues() {
		return reverse.getKeys();
	}

	/** Delegates to {@link IntRelation#getSinks()}. */
	@Override
	public IntList getSinks() {
		return super.getSinks();
	}

	/** Delegates to {@link IntRelation#hashCode()}, consistent with {@link #equals(Object)}. */
	@Override
	public int hashCode() {
		return super.hashCode();
	}

	/** Iterates over the keys of the forward relation. */
	@Override
	public Iterable<Integer> keys() {
		return super.keys();
	}

	/** Iterates over the values, i.e. the keys of the reverse index. */
	public Iterable<Integer> values() {
		return reverse.keys();
	}

	/** Adds the pair {@code (id, v)} and mirrors it in the reverse index. */
	@Override
	public void put(int id, int v) {
		super.put(id, v);
		reverse.put(v, id);
	}

	/** Adds {@code (id, v[ofs+i])} for {@code 0 <= i < len}, mirroring each pair in the reverse index. */
	@Override
	public void put(int id, int[] v, int ofs, int len) {
		super.put(id, v, ofs, len);

		for (int i = 0; i < len; i++) {
			reverse.put(v[ofs + i], id);
		}
	}

	/** Adds {@code (id, x)} for every {@code x} in {@code v}. */
	@Override
	public void put(int id, int[] v) {
		put(id, v, 0, v.length);
	}

	/**
	 * Removes the pair {@code (id, v)}.
	 *
	 * @return the count reported by the forward removal; the reverse entry is only
	 *         removed when it is positive
	 */
	@Override
	public int remove(int id, int v) {
		int r = super.remove(id, v);
		if (r > 0) {
			reverse.remove(v, id);
		}
		return r;
	}

	/**
	 * Removes all pairs with key {@code id}.
	 *
	 * @return the number of values {@code id} had before removal
	 */
	@Override
	public int remove(int id) {
		int[] v = super.get(id);
		if (v.length > 0) {
			super.remove(id);
			if (v.length > shiftThreshold) {
				// many pairs: one sweep over the reverse index
				reverse.removeValue(id);
			}
			else {
				// few pairs: cheaper to remove them one by one
				for (int x : v) {
					reverse.remove(x, id);
				}
			}
		}
		return v.length;
	}

	/** Returns the number of pairs in the forward relation. */
	@Override
	public int size() {
		return super.size();
	}

	/**
	 * Removes all pairs with value {@code v}.
	 *
	 * @return the number of keys that mapped to {@code v} before removal
	 */
	@Override
	public int removeValue(int v) {
		int[] k = reverse.get(v);
		if (k.length > 0) {
			reverse.remove(v);

			if (k.length > shiftThreshold) {
				// many pairs: one sweep over the forward relation
				super.removeValue(v);
			}
			else {
				for (int x : k) {
					super.remove(x, v);
				}
			}
		}
		return k.length;
	}

	/** Delegates to {@link IntRelation#toString()}; shows the forward relation. */
	@Override
	public String toString() {
		return super.toString();
	}

	/**
	 * Prunes sinks reachable backwards from {@code start}; see {@link #doPruneSinks(IntList)}.
	 * {@code start} is copied, so the caller's list is not consumed.
	 */
	public int pruneSinks(IntList start) {
		return doPruneSinks(new IntList(start));
	}

	/** Prunes sinks reachable backwards from {@code start}; see {@link #doPruneSinks(IntList)}. */
	public int pruneSinks(int[] start) {
		return doPruneSinks(new IntList(start));
	}

	/** Prunes sinks reachable backwards from the single value {@code start}. */
	public int pruneSinks(int start) {
		IntList todo = new IntList();
		todo.add(start);
		return doPruneSinks(todo);
	}

	/**
	 * Repeatedly removes sinks, starting from the candidates in {@code todo} (which is consumed).
	 *
	 * <p>A candidate that has no values left (ignoring already-masked ones) gets every pair
	 * pointing at it masked, and the keys of those pairs become new candidates. Masked pairs
	 * are physically removed from both directions once, at the end, by {@link #flushMasked()}.
	 *
	 * @return the number of pairs masked (and therefore removed)
	 */
	protected int doPruneSinks(IntList todo) {
		int c = 0;
		// Values already masked in this run. Processing one again would find nothing left to
		// mask and would only re-queue the same keys, which can explode on DAGs with shared paths.
		Set<Integer> done = new HashSet<Integer>();

		while (todo.size() > 0) {
			int id = todo.remove(todo.size() - 1); // LIFO: pop from the end

			if (done.contains(id)) continue;
			if (containsKeyIgnoring(id, maskMarker)) continue; // still has live values: not a sink (yet)

			done.add(id);
			// The reverse index is not touched until flushMasked(), so this still lists every
			// key that pointed at id.
			int[] p = getKeys(id);
			todo.addAll(todo.size(), p);
			c += maskValue(p, id);
		}

		flushMasked();

		return c;
	}

	/**
	 * Overwrites {@code value} with the mask marker in the pairs of each key in {@code keys},
	 * and records {@code value} so its reverse entries are dropped by {@link #flushMasked()}.
	 *
	 * @return the total count reported by the forward replacements
	 */
	protected int maskValue(int[] keys, int value) {
		int c = 0;
		masked.add(value);

		for (int key : keys) {
			// super.replace on purpose: the reverse index is cleaned up in bulk by flushMasked()
			c += super.replace(key, value, maskMarker);
		}
		return c;
	}

	/**
	 * Physically removes all masked pairs: forward pairs carrying the mask marker, and the
	 * reverse entries of every value recorded by {@link #maskValue(int[], int)}.
	 *
	 * @return the count reported by the reverse removal
	 */
	protected int flushMasked() {
		super.removeValue(maskMarker);
		int c = reverse.remove(masked.toIntArray());

		masked.clear();
		return c;
	}

	/**
	 * Removes, in a single pass, every pair whose value is not a key at the time of the call,
	 * together with the matching reverse entries. Presumably O(n log n) if containsKey is a
	 * binary search.
	 *
	 * @return the number of pairs removed
	 */
	@Override
	protected int stripSinks() {
		// Decide first, compact second: containsKey() reads ids[0..size), and compacting in
		// place while querying it would make keys appear to vanish mid-pass. That would drop
		// pairs inconsistently with the reverse.remove(rm) sweep below.
		boolean[] drop = new boolean[size];
		IntList rm = new IntList();

		for (int i = 0; i < size; i++) {
			if (!containsKey(val[i])) {
				drop[i] = true;
				rm.add(val[i]);
			}
		}

		int j = 0;
		for (int i = 0; i < size; i++) {
			if (drop[i]) continue;

			ids[j] = ids[i];
			val[j] = val[i];
			j++;
		}

		reverse.remove(rm.toIntArray());

		int removed = size - j;
		size = j;
		return removed;
	}

	/**
	 * Returns keys that never occur as a value, computed as the sinks of the reverse index.
	 */
	@Override
	public IntList getSources() {
		return reverse.getSinks();
	}

	/** Adds all pairs of {@code rel}, copying its forward arrays and its reverse index directly. */
	public void putAll(IntBiRelation rel) {
		// Capture before anything is added: when a side starts out empty, rel's already
		// sorted arrays can be taken over without re-sorting.
		boolean sortForward = size > 0;
		boolean sortReverse = reverse.size > 0;

		super.putAll(rel.ids, rel.val, 0, rel.size, sortForward);
		reverse.putAll(rel.reverse.ids, rel.reverse.val, 0, rel.reverse.size, sortReverse);
	}

	/**
	 * Adds {@code len} parallel entries of {@code k}/{@code v} from {@code ofs}. The reverse
	 * side is always sorted, since the values arrive in arbitrary order.
	 */
	@Override
	protected void putAll(int[] k, int[] v, int ofs, int len, boolean sort) {
		super.putAll(k, v, ofs, len, sort);
		reverse.putAll(v, k, ofs, len, true);
	}

	/** @return the bulk-removal threshold; see {@link #shiftThreshold} */
	public int getShiftThreshold() {
		return shiftThreshold;
	}

	/** @param shiftThreshold the new bulk-removal threshold; see {@link #shiftThreshold} */
	public void setShiftThreshold(int shiftThreshold) {
		this.shiftThreshold = shiftThreshold;
	}

	/** Delegates to {@link IntRelation#containsKeyIgnoring(int, int)}. */
	@Override
	public boolean containsKeyIgnoring(int id, int ignoring) {
		return super.containsKeyIgnoring(id, ignoring);
	}

	/** Delegates to {@link IntRelation#countIgnoring(int, int)}. */
	@Override
	public int countIgnoring(int id, int ignoring) {
		return super.countIgnoring(id, ignoring);
	}

	/**
	 * Replaces value {@code v} by {@code n} in the pairs with key {@code id}.
	 *
	 * @return the count reported by the forward replacement
	 */
	@Override
	public int replace(int id, int v, int n) {
		int c = super.replace(id, v, n);

		if (c > 0) {
			// Only (v, id) changes. Other keys may still map to v, so reverse.remove(v),
			// which drops all of them, would corrupt the index.
			reverse.remove(v, id);
			reverse.put(n, id);
		}

		return c;
	}

	/**
	 * Replaces each value of key {@code id} by {@code n}.
	 *
	 * @return the number of values {@code id} had (0 if none, in which case nothing changes)
	 */
	@Override
	public int fill(int id, int n) {
		int[] v = super.get(id);
		if (v.length == 0) return 0;

		super.fill(id, n);

		// XXX: this is pretty slow (one reverse removal and insertion per value)
		for (int x : v) {
			// Remove only (x, id): other keys mapping to x are unaffected by this fill.
			reverse.remove(x, id);
			reverse.put(n, id);
		}

		return v.length;
	}

	/**
	 * Removes all pairs whose key is in {@code del}.
	 *
	 * @return the count reported by the forward removal
	 */
	@Override
	public int remove(int[] del) {
		int c = super.remove(del);

		// XXX: this is pretty slow (one sweep of the reverse index per deleted key)
		for (int i = 0; i < del.length; i++) {
			reverse.removeValue(del[i]);
		}

		return c;
	}

	/** Returns the number of values of key {@code id}. */
	@Override
	public int count(int id) {
		return super.count(id);
	}

	/** Returns the number of keys mapping to value {@code v}, read from the reverse index. */
	public int countKeys(int v) {
		return reverse.count(v);
	}
}
