package de.brightbyte.data;

import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

//import de.brightbyte.ir.data.FeatureWeight;


/**
 * A {@link LabeledVector} backed by a {@link HashMap} from label to weight.
 * Labels that have no entry in the map are treated as having the value 0
 * (see {@link #get(Object)}).
 * <p>
 * The euclidean length and the sum of all values are cached; the caches are
 * reset by {@link #modified()}, which every mutating method of this class calls.
 *
 * @param <K> the type of the labels (keys) of this vector
 */
public class MapLabeledVector<K> implements LabeledVector<K> {
	
	/** The stored entries: label to weight. */
	protected Map<K, Double> values = new HashMap<K, Double>();
	
	/** Cached result of {@link #getLength()}; NaN means "not cached". */
	private double length = Double.NaN;
	
	/** Cached result of {@link #getTotal()}; NaN means "not cached". */
	private double total = Double.NaN;
	
	/** Creates an empty vector. */
	public MapLabeledVector() {
		//noop
	}

	/*
	public MapFeatureVector(FeatureWeight<K>[] weights) {
		this();
		adds(weights);
	}
	*/
	
	/**
	 * Creates a vector and adds all entries of the given map to it.
	 *
	 * @param weights the label/weight pairs to add
	 */
	public MapLabeledVector(Map<K, ? extends Number> weights) {
		this();
		add(weights);
	}
	
	/**
	 * Creates a vector and adds the same weight for every given label.
	 *
	 * @param labels the labels to add
	 * @param weight the weight to add for each label
	 */
	public MapLabeledVector(Set<K> labels, double weight) {
		this();
		add(labels, weight);
	}
	
	/**
	 * Adds the value of every label of the given vector to this vector.
	 *
	 * @param vs the vector to add
	 */
	public void add(LabeledVector<K> vs) {
		for (K k : vs.labels()) {
			add(k, vs.get(k));
		}
	}
	
	/**
	 * Adds the same weight to every one of the given labels.
	 *
	 * @param labels the labels to update
	 * @param weight the weight to add to each label
	 */
	public void add(Set<K> labels, double weight) {
		for (K k : labels) {
			add(k, weight);
		}
	}
	
	/**
	 * Adds every entry of the given map, using {@link Number#doubleValue()}
	 * of the mapped value as the weight.
	 *
	 * @param weights the label/weight pairs to add
	 */
	public void add(Map<K, ? extends Number> weights) {
		for (Map.Entry<K, ? extends Number> w : weights.entrySet()) {
			add(w.getKey(), w.getValue().doubleValue());
		}
	}
	
	/*
	public void add(FeatureWeight<K> w) {
		add(w.getFeature(), w.getWeight());
	}*/
	
	/**
	 * Adds the given weight to the value stored for the given label.
	 * If the label already had an entry and the sum is exactly 0, the entry
	 * is removed. Otherwise the sum is stored; note that this means adding 0
	 * for a label that had no entry stores an explicit 0 entry.
	 *
	 * @param key the label to update
	 * @param weight the weight to add
	 */
	public void add(K key, double weight) {
		Double v = values.get(key);
		if (v!=null) weight += v;
		
		if (weight==0 && v!=null) values.remove(key);
		else values.put(key, weight);
		
		modified();
	}
	
	/**
	 * Overwrites the values in this vector with the value of every label
	 * of the given vector. Labels not present in {@code vs} are left alone.
	 *
	 * @param vs the vector to copy values from
	 */
	public void set(LabeledVector<K> vs) {
		for (K k : vs.labels()) {
			set(k, vs.get(k));
		}
	}
	
	/**
	 * Sets every one of the given labels to the same weight.
	 *
	 * @param labels the labels to set
	 * @param weight the weight to store for each label
	 */
	public void set(Set<K> labels, double weight) {
		for (K k : labels) {
			set(k, weight);
		}
	}
		
	/**
	 * Stores the given weight for the given label, replacing any previous value.
	 *
	 * @param key the label to set
	 * @param weight the weight to store
	 */
	public void set(K key, double weight) {
		values.put(key, weight);
		
		modified();
	}
	
	/**
	 * Resets the cached length and total, so they are recomputed on the next
	 * call to {@link #getLength()} resp. {@link #getTotal()}.
	 */
	protected void modified() {
		length = Double.NaN;
		total = Double.NaN;
	}
	
	/**
	 * Returns the value stored for the given label.
	 *
	 * @param key the label to look up
	 * @return the stored value, or 0 if there is no entry for the label
	 */
	public double get(K key) {
		Double v = values.get(key);
		return v == null ? 0 : v;
	}
	
	/**
	 * Returns the labels that have an entry in this vector. This is the key
	 * set of the backing map itself, not a copy; removing labels through it
	 * bypasses {@link #modified()}.
	 *
	 * @return the key set of the backing map
	 */
	public Set<K> labels() {
		return values.keySet();
	}
	
	/**
	 * @return the number of entries stored in this vector
	 */
	public int size() {
		return values.size();
	}

	/**
	 * Computes the sum of all stored values. Unlike {@link #getTotal()},
	 * this recomputes the sum on every call and does not use the cache.
	 *
	 * @return the sum of all stored values
	 */
	public double total() {
		double t = 0;
		for (double v: values.values()) {
			t+= v;
		}
		
		return t;
	}

	//TODO: trim top x / bottom y
	
	/**
	 * Removes the entry for the given label, if there is one.
	 *
	 * @param key the label to remove
	 */
	public void remove(K key) {
		values.remove(key);
		modified();
	}
	
	/**
	 * Prints one line of the form {@code label: value} per entry.
	 *
	 * @param output the writer to print to
	 */
	public void dump(PrintWriter output) {
		for (Map.Entry<K, Double> e : values.entrySet()) {
			output.println(e.getKey()+": "+e.getValue());
		}
	}

	/**
	 * Prints one line of the form {@code label: value} per entry.
	 *
	 * @param output the stream to print to
	 */
	public void dump(PrintStream output) {
		for (Map.Entry<K, Double> e : values.entrySet()) {
			output.println(e.getKey()+": "+e.getValue());
		}
	}
	
	/**
	 * @return the string representation of the backing map
	 */
	@Override
	public String toString() {
		return values.toString();
	}

	/**
	 * Collects the labels of this vector and of the given vector into a
	 * new set.
	 *
	 * @param other the other vector
	 * @return a newly created set containing the labels of both vectors
	 */
	public Iterable<K> combinedLabels(LabeledVector<K> other) {
		HashSet<K> keys = new HashSet<K>(size()+other.size());
		keys.addAll(values.keySet());
		
		for (K feature: other.labels()) keys.add(feature);
		
		return keys;
	}

	/**
	 * Computes the component-wise difference {@code this - v} as a new vector.
	 * The result has an entry for every label of either vector, including
	 * labels for which the difference is 0.
	 *
	 * @param v the vector to subtract
	 * @return a new vector holding the differences
	 */
	public LabeledVector<K> difference(LabeledVector<K> v) {
		MapLabeledVector<K> x = new MapLabeledVector<K>();
		for (K feature: combinedLabels(v)) {
			double w1 = get(feature);
			double w2 = v.get(feature);
			x.add(feature, w1 - w2 );
		}
		
		return x;
	}

	/**
	 * Computes the euclidean distance between this vector and the given one,
	 * taken over the labels of both vectors.
	 *
	 * @param v the other vector
	 * @return the square root of the sum of squared component differences
	 */
	public double distance(LabeledVector<K> v) {
		double d = 0;
		for (K feature: combinedLabels(v)) {
			double w1 = get(feature);
			double w2 = v.get(feature);
			double w = w1 - w2;
			d +=  w*w;
		}
		
		return Math.sqrt(d);
	}

	/**
	 * Computes the component-wise sum {@code this + v} as a new vector.
	 * The result has an entry for every label of either vector, including
	 * labels for which the sum is 0.
	 *
	 * @param v the vector to add
	 * @return a new vector holding the sums
	 */
	public LabeledVector<K> summ(LabeledVector<K> v) {
		MapLabeledVector<K> x = new MapLabeledVector<K>();
		for (K feature: combinedLabels(v)) {
			double w1 = get(feature);
			double w2 = v.get(feature);
			x.add(feature, w1 + w2 );
		}
		
		return x;
	}

	/**
	 * Returns the euclidean length of this vector, i.e. the square root of
	 * the sum of the squared values. The result is cached until
	 * {@link #modified()} is called.
	 *
	 * @return the euclidean length
	 */
	public double getLength() {
		if (!Double.isNaN(length)) return length;
		
		double w = 0;
		for (K f: labels()) {
			double d = get(f);
			w += d * d;
		}
		
		length = Math.sqrt(w);
		return length;
	}

	/**
	 * Returns the sum of all values of this vector. The result is cached
	 * until {@link #modified()} is called.
	 *
	 * @return the sum of all values
	 */
	public double getTotal() {
		if (!Double.isNaN(total)) return total;
		
		double t = 0;
		for (K f: labels()) {
			t += get(f);
		}
		
		total = t;
		return total;
	}

	/**
	 * Computes the scalar (dot) product of this vector and the given one.
	 * Only the labels of this vector are visited: a label without an entry
	 * here has the value 0 and so contributes nothing to the product.
	 *
	 * @param v the other vector
	 * @return the sum of the products of corresponding values
	 */
	public double scalar(LabeledVector<K> v) {
		double d = 0;
		for (K feature: labels()) {
			double wa = get(feature);
			double wb = v.get(feature);
			
			d += wa * wb;
		}
		return d;
	}

	/**
	 * Returns a new vector in which every value of this vector is multiplied
	 * by the given factor. This vector is not changed.
	 *
	 * @param scale the factor to multiply each value with
	 * @return a new, scaled vector
	 */
	public LabeledVector<K> scaled(double scale) {
		MapLabeledVector<K> v = new MapLabeledVector<K>();
		for (K feature: labels()) {
			double w = get(feature);
			v.add(feature, w*scale);
		}
		
		return v;
	}
	

	/**
	 * Removes all entries whose value is less than {@code min} or greater
	 * than {@code max}.
	 * <p>
	 * A {@code min} of 0 or less is treated as 0, so entries with a negative
	 * value are always removed. A {@code max} of 0 or less is replaced by
	 * {@link Double#MAX_VALUE}.
	 *
	 * @param min the smallest value to keep; values &lt;= 0 are treated as 0
	 * @param max the largest value to keep; values &lt;= 0 are replaced by
	 *            {@link Double#MAX_VALUE}
	 */
	public void trim(double min, double max) {
		if (min <= 0) min = 0;
		if (max <= 0) max = Double.MAX_VALUE;
		
		boolean removed = false;
		Iterator<Map.Entry<K, Double>> it = values.entrySet().iterator();
		while (it.hasNext()) {
			Map.Entry<K, Double> e = it.next();
			double v = e.getValue(); 
			if (v < min || v > max) {
				it.remove();
				removed = true;
			}
		}
		
		// Removing entries changes length and total, so the cached values
		// must be reset just like in every other mutating method.
		if (removed) modified();
	}
	
}
