package de.brightbyte.data;

import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;



/**
 * A sparse, labeled matrix of {@code double} values, stored in a {@link HashMap}
 * keyed by (row label, column label) pairs: slot A of a key is the row label,
 * slot B is the column label.
 * <p>
 * Only explicitly written cells are stored; reading any other cell yields
 * {@code 0}. A matrix created as <em>symmetric</em> builds its keys as
 * {@code ReflexiveKeyPair} instances, and its row/column accessors treat a label
 * found in either slot of a key as a match.
 * NOTE(review): symmetric mode presumably relies on {@code ReflexiveKeyPair}
 * treating (a, b) and (b, a) as the same key, and on {@code K} and {@code L}
 * being the same type (the label casts in the row/column methods are
 * unchecked) — confirm.
 * <p>
 * No synchronization is performed; an instance shared between threads must be
 * guarded externally.
 *
 * @param <K> type of the row labels
 * @param <L> type of the column labels
 */
public class MapLabeledMatrix<K, L> implements LabeledMatrix<K, L> {
	/** Stored cells; a key that is absent means the cell's value is 0. */
	protected Map <Pair<K, L>, Double> values = new HashMap<Pair<K, L>, Double>();

	/** If true, keys are built as ReflexiveKeyPair and labels match in either key slot. */
	protected boolean symmetric;

	/** Creates an empty, non-symmetric matrix. */
	public MapLabeledMatrix() {
		this(false);
	}

	/**
	 * Creates an empty matrix.
	 *
	 * @param symmetric whether the matrix is symmetric (see class comment)
	 */
	public MapLabeledMatrix(boolean symmetric) {
		this.symmetric = symmetric;
	}

	/**
	 * Creates a matrix holding a copy of the given cells.
	 * <p>
	 * NOTE(review): keys are copied as given rather than rebuilt via
	 * {@link #makeKey}, so in symmetric mode they presumably must already compare
	 * equal to the keys {@code makeKey} produces — confirm.
	 *
	 * @param values initial cells, or {@code null} for an empty matrix
	 * @param symmetric whether the matrix is symmetric (see class comment)
	 */
	public MapLabeledMatrix(Map<? extends Pair<K, L>, Double> values, boolean symmetric) {
		this(symmetric);
		if (values != null) this.values.putAll(values);
	}

	/**
	 * Builds the storage key for cell (a, b).
	 *
	 * @param a row label
	 * @param b column label
	 * @return a {@code ReflexiveKeyPair} when symmetric, otherwise a plain {@code Pair}
	 */
	@SuppressWarnings("unchecked") // ReflexiveKeyPair is instantiated raw and cast to Pair<K, L>
	protected Pair<K, L> makeKey(K a, L b) {
		if (symmetric) return (Pair<K, L>) new ReflexiveKeyPair(a, b);
		return new Pair<K, L>(a, b);
	}

	/**
	 * Unboxes a stored value, treating a missing ({@code null}) value as 0.
	 * Used everywhere a stored value is read so that all readers agree.
	 */
	private static double valueOrZero(Double v) {
		return v == null ? 0 : v;
	}

	/**
	 * Adds {@code w} to cell (a, b); a missing cell counts as 0.
	 *
	 * @param a row label
	 * @param b column label
	 * @param w amount to add
	 */
	public void add(K a, L b, double w) {
		add(makeKey(a, b), w);
	}

	/**
	 * Adds {@code w} to the cell stored under {@code pair}; a missing cell counts as 0.
	 * The key is used as given (it is not rebuilt via {@link #makeKey}).
	 *
	 * @param pair cell key
	 * @param w amount to add
	 */
	public void add(Pair<K, L> pair, double w) {
		values.put(pair, valueOrZero(values.get(pair)) + w);
	}

	/**
	 * Combines {@code w} into cell (a, b) as {@code d + (1 - d) * w}, where
	 * {@code d} is the current value (0 if missing).
	 *
	 * @param a row label
	 * @param b column label
	 * @param w value to merge in
	 */
	public void merge(K a, L b, double w) {
		merge(makeKey(a, b), w);
	}

	/**
	 * Combines {@code w} into the cell stored under {@code pair} as
	 * {@code d + (1 - d) * w}, where {@code d} is the current value (0 if missing).
	 *
	 * @param pair cell key, used as given
	 * @param w value to merge in
	 */
	public void merge(Pair<K, L> pair, double w) {
		double d = valueOrZero(values.get(pair));
		// probabilistic "OR": equals 1 - (1 - d) * (1 - w); meaningful when d and w are probabilities
		values.put(pair, d + (1 - d) * w);
	}

	/**
	 * Overwrites cell (a, b) with {@code w}.
	 *
	 * @param a row label
	 * @param b column label
	 * @param w new value
	 */
	public void set(K a, L b, double w) {
		set(makeKey(a, b), w);
	}

	/**
	 * Overwrites the cell stored under {@code pair} with {@code w}.
	 *
	 * @param pair cell key, used as given
	 * @param w new value
	 */
	public void set(Pair<K, L> pair, double w) {
		values.put(pair, w);
	}

	/**
	 * @param a row label
	 * @param b column label
	 * @return whether cell (a, b) is explicitly stored
	 */
	public boolean contains(K a, L b) {
		return contains(makeKey(a, b));
	}

	/**
	 * @param a row label
	 * @param b column label
	 * @return the value of cell (a, b), or 0 if it is not stored
	 */
	public double get(K a, L b) {
		return get(makeKey(a, b));
	}

	/**
	 * @param pair cell key, used as given
	 * @return whether a cell is stored under {@code pair}
	 */
	public boolean contains(Pair<K, L> pair) {
		return values.containsKey(pair);
	}

	/**
	 * @param pair cell key, used as given
	 * @return the value stored under {@code pair}, or 0 if none
	 */
	public double get(Pair<K, L> pair) {
		return valueOrZero(values.get(pair));
	}

	/**
	 * Prints one line {@code "<key>: <value>"} per stored cell, in the backing
	 * HashMap's (unspecified) iteration order.
	 *
	 * @param output destination writer
	 */
	public void dump(PrintWriter output) {
		for (Map.Entry<Pair<K, L>, Double> e : values.entrySet()) {
			output.println(e.getKey() + ": " + e.getValue());
		}
	}

	/**
	 * Prints one line {@code "<key>: <value>"} per stored cell, in the backing
	 * HashMap's (unspecified) iteration order.
	 *
	 * @param output destination stream
	 */
	public void dump(PrintStream output) {
		for (Map.Entry<Pair<K, L>, Double> e : values.entrySet()) {
			output.println(e.getKey() + ": " + e.getValue());
		}
	}

	/**
	 * Extracts column {@code c} as a new vector indexed by row label. When
	 * symmetric, cells whose row label equals {@code c} also contribute, indexed
	 * by their column label.
	 *
	 * @param c column label
	 * @return a new vector; empty if nothing is stored for {@code c}
	 */
	@SuppressWarnings("unchecked") // symmetric matrices are presumed to have K == L
	public LabeledVector<K> column(L c) {
		MapLabeledVector<K> v = new MapLabeledVector<K>();
		for (Map.Entry<Pair<K, L>, Double> e : values.entrySet()) {
			Pair<K, L> k = e.getKey();
			double w = valueOrZero(e.getValue());
			if (k.getB().equals(c)) v.set(k.getA(), w);
			else if (symmetric && k.getA().equals(c)) v.set((K) k.getB(), w);
		}
		return v;
	}

	/**
	 * Extracts row {@code r} as a new vector indexed by column label. When
	 * symmetric, cells whose column label equals {@code r} also contribute,
	 * indexed by their row label.
	 *
	 * @param r row label
	 * @return a new vector; empty if nothing is stored for {@code r}
	 */
	@SuppressWarnings("unchecked") // symmetric matrices are presumed to have K == L
	public LabeledVector<L> row(K r) {
		MapLabeledVector<L> v = new MapLabeledVector<L>();
		for (Map.Entry<Pair<K, L>, Double> e : values.entrySet()) {
			Pair<K, L> k = e.getKey();
			double w = valueOrZero(e.getValue());
			if (k.getA().equals(r)) v.set(k.getB(), w);
			else if (symmetric && k.getB().equals(r)) v.set((L) k.getA(), w);
		}
		return v;
	}

	/**
	 * Removes every stored cell in column {@code c} (key slot B). When symmetric,
	 * cells whose row label (slot A) equals {@code c} are removed too.
	 *
	 * @param c column label
	 */
	public void removeColumn(L c) {
		Iterator<Pair<K, L>> it = values.keySet().iterator();
		while (it.hasNext()) {
			Pair<K, L> k = it.next();
			// the column label lives in slot B, consistent with column()/columns()
			if (k.getB().equals(c)) it.remove();
			else if (symmetric && k.getA().equals(c)) it.remove();
		}
	}

	/**
	 * Removes every stored cell in row {@code r} (key slot A). When symmetric,
	 * cells whose column label (slot B) equals {@code r} are removed too.
	 *
	 * @param r row label
	 */
	public void removeRow(K r) {
		Iterator<Pair<K, L>> it = values.keySet().iterator();
		while (it.hasNext()) {
			Pair<K, L> k = it.next();
			// the row label lives in slot A, consistent with row()/rows()
			if (k.getA().equals(r)) it.remove();
			else if (symmetric && k.getB().equals(r)) it.remove();
		}
	}

	/**
	 * Collects the distinct column labels of all stored cells (plus row labels
	 * when symmetric).
	 *
	 * @return a new set of labels
	 */
	@SuppressWarnings("unchecked") // symmetric matrices are presumed to have K == L
	public Iterable<L> columns() {
		Set<L> keys = new HashSet<L>();
		for (Pair<K, L> k : values.keySet()) {
			keys.add(k.getB());
			if (symmetric) keys.add((L) k.getA());
		}
		return keys;
	}

	/**
	 * Collects the distinct row labels of all stored cells (plus column labels
	 * when symmetric).
	 *
	 * @return a new set of labels
	 */
	@SuppressWarnings("unchecked") // symmetric matrices are presumed to have K == L
	public Iterable<K> rows() {
		Set<K> keys = new HashSet<K>();
		for (Pair<K, L> k : values.keySet()) {
			keys.add(k.getA());
			if (symmetric) keys.add((K) k.getB());
		}
		return keys;
	}

	/**
	 * Counts the elements of an Iterable, using {@code size()} when it is a
	 * Collection. This avoids a ClassCastException if an overriding
	 * {@code columns()}/{@code rows()} returns some other Iterable.
	 */
	private static int count(Iterable<?> items) {
		if (items instanceof Collection) return ((Collection<?>) items).size();
		int n = 0;
		for (Iterator<?> it = items.iterator(); it.hasNext(); it.next()) n++;
		return n;
	}

	/** @return the number of distinct labels reported by {@link #columns()} */
	public int getColumnCount() {
		return count(columns());
	}

	/** @return the number of distinct labels reported by {@link #rows()} */
	public int getRowCount() {
		return count(rows());
	}

	/**
	 * Removes every stored cell whose value lies outside {@code [min, max]}.
	 * A stored {@code null} value is treated as 0, matching {@link #get(Pair)}.
	 *
	 * @param min inclusive lower bound of values to keep
	 * @param max inclusive upper bound of values to keep
	 */
	public void trim(double min, double max) {
		Iterator<Map.Entry<Pair<K, L>, Double>> it = values.entrySet().iterator();
		while (it.hasNext()) {
			double v = valueOrZero(it.next().getValue());
			if (v < min || v > max) it.remove();
		}
	}

	/** @return whether this matrix was created as symmetric */
	public boolean isSymmetric() {
		return symmetric;
	}
}
