package de.brightbyte.data;

import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

/**
 * A set of non-negative {@code int} values stored as a bitmap that is split into
 * chunks of {@code chunkSize} 64-bit words. Chunks are allocated lazily on first
 * write, so only regions that have been written to occupy memory.
 * <p>
 * The number of set bits is tracked incrementally, so {@link #size()} does not scan.
 * No synchronization is performed; concurrent use must be guarded externally.
 */
public class ChunkyBitSet implements Iterable<Integer> {
	/** Number of bits stored in each {@code long} word. */
	protected static final int BITS_PER_UNIT = 64;

	/** Initial number of slots in the top-level chunk table. */
	private static final int INITIAL_CHUNK_SLOTS = 1024;

	/** Number of {@code long} words per chunk. */
	private final int chunkSize;
	/** Chunk table; a {@code null} slot means every bit in that chunk is clear. */
	private long[][] chunks = new long[INITIAL_CHUNK_SLOTS][];
	/** Number of bits currently set. */
	private int size = 0;
	
	/** Creates an empty set using chunks of 1024 words (65536 bits each). */
	public ChunkyBitSet() {
		this(1024);
	}
	
	/**
	 * Creates an empty set.
	 *
	 * @param chunkSize number of 64-bit words per lazily allocated chunk
	 * @throws IllegalArgumentException if {@code chunkSize} is not positive
	 */
	public ChunkyBitSet(int chunkSize) {
		if (chunkSize <= 0) {
			throw new IllegalArgumentException("chunkSize must be positive: " + chunkSize);
		}
		this.chunkSize = chunkSize;
	}
	
	/** @return the number of bits currently set */
	public int size() {
		return size;
	}
	
	/** @return {@code true} if no bit is set */
	public boolean isEmpty() {
		return size() == 0;
	}

	/** Clears every bit and releases all chunks. */
	public void clear() {
		chunks = new long[INITIAL_CHUNK_SLOTS][];
		size = 0;
	}

	/**
	 * Returns the chunk with the given number, optionally allocating it. When
	 * allocating, the chunk table is grown so that slot {@code x} is addressable.
	 *
	 * @param x      chunk number
	 * @param create whether to allocate a missing chunk
	 * @return the chunk, or {@code null} if it is missing and {@code create} is false
	 */
	protected long[] getChunk(int x, boolean create) {
		if (x >= chunks.length) {
			if (!create) return null;
			
			// Grow by 50%, but never less than needed to make slot x exist.
			long n = Math.max((long) chunks.length * 3 / 2, (long) x + 1);
			if (n > Integer.MAX_VALUE) n = Integer.MAX_VALUE;
			chunks = Arrays.copyOf(chunks, (int) n);
		}
		
		if (chunks[x] == null && create) {
			chunks[x] = new long[chunkSize];
		}
		
		return chunks[x];
	}

	/**
	 * Rejects bit indexes that this set cannot hold.
	 *
	 * @throws IndexOutOfBoundsException if {@code i} is negative
	 */
	private static void checkIndex(int i) {
		if (i < 0) throw new IndexOutOfBoundsException("bad bit index: " + i);
	}

	/** Mask selecting bit {@code i} within its word; {@code i} must be non-negative. */
	private static long bitMask(int i) {
		return 1L << (i % BITS_PER_UNIT);
	}
	
	/**
	 * @param i a non-negative bit index
	 * @return whether bit {@code i} is set
	 * @throws IndexOutOfBoundsException if {@code i} is negative
	 */
	public boolean get(int i) {
		checkIndex(i);
		int unit = i / BITS_PER_UNIT;
		
		long[] chunk = getChunk(unit / chunkSize, false);
		if (chunk == null) return false;
		
		return (chunk[unit % chunkSize] & bitMask(i)) != 0;
	}
	
	/**
	 * @return whether any of {@code a[ofs]..a[ofs+len-1]} is set
	 */
	public boolean containsAny(int[] a, int ofs, int len) {
		//TODO: optimize for sorted array...
		
		int top = ofs + len;
		for (int i=ofs; i<top; i++) {
			if (get(a[i])) return true;
		}
		
		return false;
	}
	
	/** @return whether any element of {@code a} is set */
	public boolean containsAny(IntList a) {
		return containsAny(a.data, 0, a.size);
	}
	
	/** @return whether any element of {@code a} is set */
	public boolean containsAny(Iterable<Integer> a) {
		for (int i: a) {
			if (get(i)) return true;
		}
		
		return false;
	}
	
	/**
	 * Adds every bit set in {@code a} to this set.
	 *
	 * @return the number of bits that were newly set
	 */
	public int addAll(ChunkyBitSet a) {
		if (chunkSize != a.chunkSize) {
			// Word layouts differ; fall back to element-wise insertion.
			return addAll((Iterable<Integer>) a);
		}
		
		if (chunks.length < a.chunks.length) {
			chunks = Arrays.copyOf(chunks, a.chunks.length);
		}
		
		int added = 0;
		
		for (int i = 0; i < a.chunks.length; i++) {
			long[] src = a.chunks[i];
			if (src == null) continue;
			
			long[] dst = chunks[i];
			for (int j = 0; j < src.length; j++) {
				long fresh = (dst == null) ? src[j] : (src[j] & ~dst[j]);
				if (fresh == 0) continue;
				
				// Only allocate a target chunk once there is something new to store.
				if (dst == null) {
					chunks[i] = dst = new long[chunkSize];
				}
				dst[j] |= fresh;
				added += Long.bitCount(fresh);
			}
		}
		
		size += added;
		return added;
	}
	
	/** Sets every element of {@code a}; @return the number of bits newly set */
	public int addAll(Iterable<Integer> a) {
		int c = 0;
		for (int i: a) if (add(i)) c++;
		return c;
	}
	
	/** Sets every element of {@code a}; @return the number of bits newly set */
	public int addAll(int[] a) {
		int c = 0;
		for (int i: a) if (add(i)) c++;
		return c;
	}
	
	/** Sets every element of {@code a}; @return the number of bits newly set */
	public int addAll(IntList a) {
		int c = 0;
		int n = a.size();
		for (int i=0; i<n; i++) if (add(a.getInt(i))) c++;
		return c;
	}
	
	/** Clears every element of {@code a}; @return the number of bits that were cleared */
	public int removeAll(Iterable<Integer> a) {
		int c = 0;
		for (int i: a) if (remove(i)) c++;
		return c;
	}
	
	/** Clears every element of {@code a}; @return the number of bits that were cleared */
	public int removeAll(int[] a) {
		int c = 0;
		for (int i: a) if (remove(i)) c++;
		return c;
	}
	
	/** Sets every element of {@code a} to {@code v}. */
	public void setAll(Iterable<Integer> a, boolean v) {
		for (int i: a) set(i, v);
	}
	
	/** Sets every element of {@code a} to {@code v}. */
	public void setAll(int[] a, boolean v) {
		for (int i: a) set(i, v);
	}
	
	/**
	 * Sets bit {@code i} to {@code v}.
	 *
	 * @return the previous value of the bit
	 */
	public boolean set(int i, boolean v) {
		return set(i, v?1:-1);
	}
	
	/**
	 * Updates bit {@code i}.
	 *
	 * @param mode 1 to set, -1 to clear, 0 to flip
	 * @return the previous value of the bit
	 */
	private boolean set(int i, int mode) {
		checkIndex(i);
		int unit = i / BITS_PER_UNIT;
		int b = unit % chunkSize;
		long m = bitMask(i);
		
		// Clearing never needs to allocate: a missing chunk already means "all clear".
		long[] chunk = getChunk(unit / chunkSize, mode >= 0);
		if (chunk == null) return false;
		
		boolean old = (chunk[b] & m) != 0;
		boolean target = (mode == 0) ? !old : (mode > 0);
		
		if (target != old) {
			chunk[b] ^= m;
			size += target ? 1 : -1;
		}
		
		return old;
	}
	
	/** Sets bit {@code i}; @return {@code true} if it was previously clear */
	public boolean add(int i) {
		boolean old = set(i, 1);
		return !old;
	}

	/** Clears bit {@code i}; @return {@code true} if it was previously set */
	public boolean remove(int i) {
		boolean old = set(i, -1);
		return old;
	}

	/** Inverts bit {@code i}. */
	public void flip(int i) {
		set(i, 0);
	}
	
	/**
	 * Hash over the set bits only, independent of chunk size and of empty chunks,
	 * so that it is consistent with {@link #equals(Object)}. Each non-zero word
	 * contributes {@code word * (globalWordIndex + 1)}, as in {@code java.util.BitSet}.
	 */
	@Override
	public int hashCode() {
		long h = 1234;
		
		for (int c = 0; c < chunks.length; c++) {
			long[] chunk = chunks[c];
			if (chunk == null) continue;
			
			long base = (long) c * chunkSize;
			for (int j = 0; j < chunk.length; j++) {
				h ^= chunk[j] * (base + j + 1); // zero words contribute nothing
			}
		}
		
		return (int) ((h >> 32) ^ h);
	}
	
	/**
	 * Two sets are equal if they are of the same class and contain the same bits,
	 * regardless of chunk size or which (possibly empty) chunks are allocated.
	 */
	@Override
	public boolean equals(Object obj) {
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (getClass() != obj.getClass())
			return false;
		
		final ChunkyBitSet other = (ChunkyBitSet) obj;
		if (size != other.size) return false;
		
		if (chunkSize != other.chunkSize) {
			// Different word layouts: with equal sizes, containment one way implies equality.
			for (int i = successor(-1); i >= 0; i = successor(i)) {
				if (!other.get(i)) return false;
			}
			return true;
		}
		
		int n = Math.max(chunks.length, other.chunks.length);
		
		for (int i=0; i<n; i++) {
			long[] a = i >= chunks.length ? null : chunks[i];
			long[] b = i >= other.chunks.length ? null : other.chunks[i];
			
			if (a==null) {
				if (b!=null && !isEmpty(b)) return false;
			}
			else if (b==null) {
				if (!isEmpty(a)) return false;
			}
			else if (!Arrays.equals(a, b)) {
				return false;
			}
		}
		
		return true;
	}

	/** @return {@code true} if every word in {@code a} is zero */
	protected static boolean isEmpty(long[] a) {
		for(long x: a) {
			if (x!=0) return false;
		}
		
		return true;
	}
	
	/**
	 * Finds the next set bit.
	 *
	 * @param idx the index to search after; pass -1 to find the first set bit
	 * @return the smallest set index strictly greater than {@code idx}, or -1 if there is none
	 * @throws IndexOutOfBoundsException if {@code idx < -1}
	 */
	protected int successor(int idx) {
		if (idx == Integer.MAX_VALUE) return -1; // no larger int index exists
		
		int start = idx + 1;
		checkIndex(start);
		
		int unit = start / BITS_PER_UNIT;
		int c = unit / chunkSize;
		int b = unit % chunkSize;
		int bit = start % BITS_PER_UNIT;
		
		for (; c < chunks.length; c++, b = 0, bit = 0) {
			long[] chunk = getChunk(c, false);
			if (chunk == null) continue;
			
			for (; b < chunk.length; b++, bit = 0) {
				// Ignore bits below the start position; test != 0 so bit 63 (sign bit) counts.
				long word = chunk[b] & (-1L << bit);
				if (word != 0) {
					long pos = ((long) c * chunkSize + b) * BITS_PER_UNIT
							+ Long.numberOfTrailingZeros(word);
					return pos > Integer.MAX_VALUE ? -1 : (int) pos;
				}
			}
		}
		
		return -1;
	}

	/** Iterates over the set bits in ascending order; supports {@link #remove()}. */
	protected class CBSIterator implements Iterator<Integer> {
		/** Next index to return, or -1 when exhausted. */
		private int index;
		/** Index last returned, or -1 if {@link #remove()} is not currently allowed. */
		private int last = -1;
		
		public CBSIterator() {
			index = successor(-1);
		}

		@Override
		public boolean hasNext() {
			return index >= 0;
		}

		@Override
		public Integer next() {
			return nextInt();
		}

		/**
		 * Unboxed variant of {@link #next()}.
		 *
		 * @throws NoSuchElementException if the iteration has no more elements
		 */
		public int nextInt() {
			if (index < 0) throw new NoSuchElementException();
			last = index;
			// The successor is computed before any remove(), so removing `last` cannot disturb it.
			index = successor(last);
			return last;
		}

		/** Clears the bit most recently returned by {@link #next()}. */
		@Override
		public void remove() {
			if (last < 0) {
				throw new IllegalStateException("next() has not been called, or remove() was already called");
			}
			ChunkyBitSet.this.remove(last);
			last = -1;
		}
		
	}
	
	/** @return an iterator over the set bits in ascending order */
	@Override
	public Iterator<Integer> iterator() {
		return new CBSIterator();
	}

	/** @return the set bits in ascending order */
	public int[] toIntArray() {
		int[] a = new int[size()];
		
		int n = 0;
		for (int i = successor(-1); i >= 0; i = successor(i)) {
			a[n++] = i;
		}
		
		return a;
	}
	
	//TODO: iterate true/false over index
	//TODO: iterate indexes of clear bits
	//TODO: find previous set/clear bit, find next clear bit
	//TODO: find last set bit
	//TODO: union, intersect, subtract
}
