package de.brightbyte.xml;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.Writer;
import java.net.URL;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.regex.Pattern;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpressionException;
import javax.xml.xpath.XPathFactory;

import org.w3c.dom.Attr;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NamedNodeMap;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

import de.brightbyte.io.IOUtil;

/**
 * Static helpers for working with W3C DOM trees: typed attribute access,
 * parsing and serialization, simple child-element navigation, and a basic
 * XInclude implementation.
 * <p>
 * NOTE: the shared {@link #xpath} instance is used without synchronization;
 * {@code javax.xml.xpath.XPath} objects are not thread-safe, so concurrent
 * calls to {@link #getElement} or {@link #xinclude} may interfere.
 * <p>
 * NOTE(review): the shared {@code DocumentBuilderFactory} is used with its
 * default configuration (no features are set here), so DTD and external
 * entity handling is whatever the JAXP implementation does by default.
 * Do not parse untrusted XML with these loaders without hardening the factory.
 */
public class XmlUtil {
	/** Splits simple element paths such as {@code "a/b/c"} into tag names. */
	public static final Pattern pathSplitPattern = Pattern.compile("/");
	public static final XPathFactory xpathFactory = XPathFactory.newInstance();
	/** Shared XPath evaluator; not thread-safe (see class comment). */
	public static final XPath xpath = xpathFactory.newXPath();
	
	/** Serializer factory; every use synchronizes on this instance. */
	protected static final TransformerFactory poxTranformerFactory = TransformerFactory.newInstance();
	/** Parser factory; also serves as the lock guarding {@link #poxBuilder}. */
	protected static final DocumentBuilderFactory poxBuilderFactory = DocumentBuilderFactory.newInstance();
	/** Lazily created shared parser; only accessed while holding the {@link #poxBuilderFactory} lock. */
	protected static DocumentBuilder poxBuilder = null;

	/**
	 * Reads a comma-separated list of integers from an attribute.
	 *
	 * @param e the element to read from
	 * @param name the attribute name
	 * @param def the value returned when the attribute is absent
	 * @return the parsed values, an empty array if the attribute is empty, or {@code def}
	 * @throws NumberFormatException if an item is not a valid integer
	 *         (items are not trimmed, so {@code "1, 2"} is rejected)
	 */
	public static int[] getIntsAttribute(Element e, String name, int[] def) {
		String encodedInts[] = getStringsAttribute(e, name, null);
		if (encodedInts == null) return def;
		
		int results[] = new int[encodedInts.length];
		for (int i = 0; i < results.length; i++) {
			results[i] = Integer.parseInt(encodedInts[i]);
		}
		return results;
	}
	
	/** Like {@link #getIntsAttribute(Element, String, int[])} with a {@code null} default. */
	public static int[] getIntsAttribute(Element e, String name) {
		return getIntsAttribute(e, name, null);
	}

	/**
	 * Reads a comma-separated list of strings from an attribute.
	 *
	 * @param e the element to read from
	 * @param name the attribute name
	 * @param def the value returned when the attribute is absent
	 * @return the items as split by {@link String#split(String)} on {@code ","}
	 *         (trailing empty items are dropped), an empty array if the attribute
	 *         value is empty, or {@code def}
	 */
	public static String[] getStringsAttribute(Element e, String name, String[] def) {
		String encoded = getAttribute(e, name, null);
		if (encoded == null) return def;
		if (encoded.length() == 0) return new String[0];
		return encoded.split(",");
	}
	
	/** Like {@link #getStringsAttribute(Element, String, String[])} with a {@code null} default. */
	public static String[] getStringsAttribute(Element e, String name) {
		return getStringsAttribute(e, name, null);
	}
	
	/**
	 * Reads an attribute holding milliseconds since the epoch as a {@link Date}.
	 *
	 * @return the date, or {@code def} if the attribute is absent
	 * @throws NumberFormatException if the value is not a valid long
	 */
	public static Date getTimestampAttribute(Element e, String name, Date def) {
		Long timestamp = getLongAttribute(e, name, null);
		if (timestamp == null) return def;
		return new Date(timestamp);
	}
	
	/** Like {@link #getTimestampAttribute(Element, String, Date)} with a {@code null} default. */
	public static Date getTimestampAttribute(Element e, String name) {
		return getTimestampAttribute(e, name, null);
	}
	
	/**
	 * Reads a boolean attribute.
	 *
	 * @return {@code TRUE} if the value equals "true" ignoring case, {@code FALSE}
	 *         for any other present value, or {@code def} if the attribute is absent
	 */
	public static Boolean getBooleanAttribute(Element e, String name, Boolean def) {
		String encoded = getAttribute(e, name, null);
		if (encoded == null) return def;
		return Boolean.valueOf(encoded);
	}
	
	/** Like {@link #getBooleanAttribute(Element, String, Boolean)} with a {@code null} default. */
	public static Boolean getBooleanAttribute(Element e, String name) {
		return getBooleanAttribute(e, name, null);
	}
	
	/**
	 * Reads a double attribute.
	 *
	 * @return the parsed value, or {@code def} if the attribute is absent
	 * @throws NumberFormatException if the value is not a valid double
	 */
	public static Double getDoubleAttribute(Element e, String name, Double def) {
		String encoded = getAttribute(e, name, null);
		if (encoded == null) return def;
		return Double.valueOf(encoded);
	}
	
	/** Like {@link #getDoubleAttribute(Element, String, Double)} with a {@code null} default. */
	public static Double getDoubleAttribute(Element e, String name) {
		return getDoubleAttribute(e, name, null);
	}
	
	/**
	 * Reads a long attribute.
	 *
	 * @return the parsed value, or {@code def} if the attribute is absent
	 * @throws NumberFormatException if the value is not a valid long
	 */
	public static Long getLongAttribute(Element e, String name, Long def) {
		String encoded = getAttribute(e, name, null);
		if (encoded == null) return def;
		return Long.valueOf(encoded);
	}
	
	/** Like {@link #getLongAttribute(Element, String, Long)} with a {@code null} default. */
	public static Long getLongAttribute(Element e, String name) {
		return getLongAttribute(e, name, null);
	}
	
	/**
	 * Reads an integer attribute.
	 *
	 * @return the parsed value, or {@code def} if the attribute is absent
	 * @throws NumberFormatException if the value is not a valid int
	 */
	public static Integer getIntegerAttribute(Element e, String name, Integer def) {
		String encoded = getAttribute(e, name, null);
		if (encoded == null) return def;
		return Integer.valueOf(encoded);
	}
	
	/** Like {@link #getIntegerAttribute(Element, String, Integer)} with a {@code null} default. */
	public static Integer getIntegerAttribute(Element e, String name) {
		return getIntegerAttribute(e, name, null);
	}
	
	/** Like {@link #getAttribute(Element, String, String)} with a {@code null} default. */
	public static String getAttribute(Element e, String name) {
		return getAttribute(e, name, null);
	}
	
	/**
	 * Reads a string attribute. Unlike {@link Element#getAttribute(String)}, this
	 * distinguishes an absent attribute (returns {@code def}) from an empty one
	 * (returns {@code ""}).
	 */
	public static String getAttribute(Element e, String name, String def) {
		if (e.getAttributeNode(name) == null) return def;
		return e.getAttribute(name);
	}
	
	/**
	 * Parses an XML file. The document URI is set to the file's URL.
	 */
	public static Document loadXML(File f) throws SAXException, IOException, ParserConfigurationException {
		// File.toURL() is deprecated because it does not escape characters that are
		// illegal in URLs (e.g. spaces); going through toURI() produces a proper URL.
		return loadXML( f.toURI().toURL() );
	}
	
	/**
	 * Parses the XML resource at the given URL using a fresh (non-shared) parser.
	 * The stream is always closed and the document URI is set to {@code u}.
	 */
	public static Document loadXML(URL u) throws SAXException, IOException, ParserConfigurationException {
		InputStream in = null;
		try {
			in = u.openStream();
			
			Document doc = loadXML(in);
			// record the source so relative references (e.g. xi:include hrefs) can be resolved
			doc.setDocumentURI(u.toExternalForm());
			
			return doc;
		}
		finally {
			if (in != null) in.close();
		}
	}
	
	/**
	 * Creates a new empty document using the lazily created shared builder.
	 */
	public static Document createDocument() throws ParserConfigurationException {
		synchronized (poxBuilderFactory) {
			if (poxBuilder == null) poxBuilder = poxBuilderFactory.newDocumentBuilder();
			return poxBuilder.newDocument();
		}
	}

	/**
	 * Parses XML from an input source.
	 *
	 * @param in the source to parse
	 * @param useSharedBuilder if true, parse with the shared builder while holding its
	 *        lock (serializing all such parses); if false, create a private builder and
	 *        parse without holding the lock
	 */
	public static Document loadXML(InputSource in, boolean useSharedBuilder) throws SAXException, IOException, ParserConfigurationException {
		if (useSharedBuilder) {
			synchronized (poxBuilderFactory) {
				if (poxBuilder == null) poxBuilder = poxBuilderFactory.newDocumentBuilder();
				
				//NOTE: poxBuilder is not thread safe, so parse inside the sync block
				return poxBuilder.parse(in);
			}
		}
		
		DocumentBuilder builder;
		synchronized (poxBuilderFactory) {
			builder = poxBuilderFactory.newDocumentBuilder();
		}

		//NOTE: builder is local and therefore safe, so parse outside the sync block
		return builder.parse(in);
	}
	
	/** Parses XML from a byte stream using a private builder. The stream is not closed. */
	public static Document loadXML(InputStream in) throws SAXException, IOException, ParserConfigurationException {
		return loadXML(new InputSource(in), false);
	}

	/**
	 * Parses XML held in a string.
	 *
	 * @throws SAXException if the string is not well-formed XML
	 * @throws ParserConfigurationException if no parser can be created
	 * @throws RuntimeException wrapping any IOException raised while parsing
	 */
	public static Document parseXML(String s) throws SAXException, ParserConfigurationException {
		try {
			return loadXML(new InputSource(new StringReader(s)), false);
		} catch (IOException e) {
			// Only I/O failures are wrapped; parse and configuration errors propagate
			// as the declared checked exceptions.
			throw new RuntimeException("IOException while reading from string reader (strange)", e);
		}
	}

	/**
	 * Evaluates an XPath expression relative to an element and returns the first
	 * matching node as an element.
	 *
	 * @return the matching element, or {@code null} if nothing matches
	 * @throws IllegalArgumentException if the first match is not an element
	 *         (e.g. an attribute or text node)
	 */
	public static Element getElement(Element element, String path) throws XPathExpressionException {
		Node n = (Node)xpath.evaluate(path, element, XPathConstants.NODE);
		if (n == null) return null;
		
		if (!(n instanceof Element)) throw new IllegalArgumentException("xpath "+path+" does not designate an XML element (tag)");
		return (Element)n;
	}

	/**
	 * Returns the first direct child element whose node name equals {@code tagName}
	 * and whose attribute {@code attributeName} equals {@code attributeValue}.
	 * A missing attribute is compared as the empty string. Returns {@code null} if none matches.
	 */
	public static Element getFirstElement(Element parent, String tagName, String attributeName, String attributeValue) {
		for (Node childNode = parent.getFirstChild(); childNode != null; childNode = childNode.getNextSibling()) {
			if (childNode instanceof Element && childNode.getNodeName().equals(tagName)) {
				Element childElement = (Element)childNode;
				if (childElement.getAttribute(attributeName).equals(attributeValue))
					return childElement;
			}
		}
		return null;
	}
	
	/**
	 * Returns the first direct child element whose node name equals {@code tagName},
	 * or the first child element of any name if {@code tagName} is {@code null}.
	 * Returns {@code null} if none matches.
	 */
	public static Element getFirstElement(Element parent, String tagName) {
		for (Node childNode = parent.getFirstChild(); childNode != null; childNode = childNode.getNextSibling()) {
			if (childNode instanceof Element && 
				(tagName == null || childNode.getNodeName().equals(tagName)))
				return (Element)childNode;
		}
		return null;
	}
	
	/**
	 * Returns the next following sibling element whose node name equals {@code tagName}
	 * (any name if {@code tagName} is {@code null}), or {@code null} if there is none.
	 */
	public static Element getNextElement(Element predecessor, String tagName) {
		for (Node node = predecessor.getNextSibling(); node != null; node = node.getNextSibling()) {
			if (node instanceof Element && 
					(tagName == null || node.getNodeName().equals(tagName)))
				return (Element)node;
		}
		return null;
	}
	
	/**
	 * Iterates over the direct child elements of a parent that have a given node name
	 * (or all child elements if the name is {@code null}). Removal is not supported.
	 */
	public static class ElementIterator implements Iterator<Element> {

		private Element element;
		private String tagName;
		
		public ElementIterator(Element parent, String tagName) {
			this.tagName = tagName;
			this.element = XmlUtil.getFirstElement(parent, tagName);
		}
		
		public boolean hasNext() {
			return element != null;
		}

		/**
		 * @throws NoSuchElementException if the iteration has no more elements
		 */
		public Element next() {
			// Iterator contract: fail with NoSuchElementException, not a NullPointerException
			// from getNextElement(null, ...)
			if (element == null) throw new NoSuchElementException();
			
			Element retVal = element;
			element = XmlUtil.getNextElement(element, tagName);
			return retVal;
		}

		public void remove() {
			throw new UnsupportedOperationException();
		}
		
	}
	
	/** Returns an {@link ElementIterator} over the matching child elements of {@code parent}. */
	public static Iterator<Element> getElementIterator(Element parent, String tagName) {
		return new ElementIterator(parent, tagName);
	}
	
	/**
	 * Follows a slash-separated path of tag names through direct children, taking
	 * the first matching child at each step. Leading or doubled slashes produce
	 * empty tag names, which never match.
	 *
	 * @return the element at the end of the path, or {@code null} if a step fails
	 */
	public static Element getElementByPath(Element parent, String path) {
		String tagNames[] = pathSplitPattern.split(path);
		Element element = parent;
		for (String tagName: tagNames) {
			element = getFirstElement(element, tagName);
			if (element == null) break;
		}
		
		return element;
	}
	
	/**
	 * Like {@link #getElementByPath(Element, String)}, but the last step additionally
	 * requires the attribute {@code attributeName} to equal {@code attributeValue}.
	 */
	public static Element getElementByPath(Element parent, String path, String attributeName, String attributeValue) {
		String tagNames[] = pathSplitPattern.split(path);
		Element element = parent;
		for (int index = 0; index < tagNames.length; index++) {
			String tagName = tagNames[index];
			boolean last = index == tagNames.length - 1;
			if (last) element = getFirstElement(element, tagName, attributeName, attributeValue);
			else element = getFirstElement(element, tagName);
			if (element == null) break;
		}
		
		return element;
	}

	/**
	 * Applies XInclude processing in place to {@code n} and its descendants;
	 * see http://www.w3.org/TR/xinclude/.
	 * <p>
	 * Include elements are recognized by their literal qualified tag name
	 * {@code "xi:include"} (prefix-based, not namespace-aware).
	 *
	 * @throws Exception on missing href, unresolvable xpointer, or load errors
	 */
	public static void xinclude(Node n) throws Exception { //FIXME: specific exceptions!
		if (n instanceof Element) {
			Element e = (Element)n;
			String tag = e.getTagName();
			
			if (tag.equals("xi:include")) {
				performeXInclude(e);
				return; //NOTE: e is gone, recursive processing must happen inside performeXInclude()
			}
		}

		Node m = n.getFirstChild();
		
		while (m != null) {
			Node next = m.getNextSibling(); //NOTE: remember reference; if xinclude is performed, siblings are lost
			
			if (m instanceof Element) xinclude(m);
			m = next;
		}		
	}
	
	/**
	 * Replaces a single {@code <xi:include>} element with the content it references.
	 * <p>
	 * For {@code parse="xml"} (the default) the target document is loaded. If an
	 * {@code xpointer} is given it is evaluated as a plain XPath, and the matching nodes
	 * are imported in place of the include element. Otherwise the target's root element
	 * is recursively xincluded and then imported. For any other {@code parse} value, the
	 * target is read as text, using {@code encoding} if given.
	 * If loading fails with an IOException and an {@code <xi:fallback>} exists, the
	 * fallback's children replace the include element instead.
	 */
	protected static void performeXInclude(Element e) throws Exception {
		String href = getAttribute(e, "href");
		if (href == null) throw new Exception("missing href attribute in <xi:include> element"); //FIXME: specific exception!
		
		String parse = getAttribute(e, "parse", "xml");
		String encoding = getAttribute(e, "encoding");

		String xpointer = getAttribute(e, "xpointer");
		
		//NOTE: getElementsByTagName searches all descendants, not only direct children
		NodeList f = e.getElementsByTagName("xi:fallback");
		Element fallback = f.getLength() > 0 ? (Element)f.item(0) : null;
		
		// resolve href against the include element's base URI when one is known
		String base = e.getBaseURI();
		URL u = base == null ? new URL(href) : new URL(new URL(base), href);
		
		try {
			if (parse.equals("xml")) {
				Document idoc = loadXML(u); //TODO: use encoding?
				Node n = idoc.getDocumentElement();
				
				if (xpointer != null) {
					//TODO: xpointers define ranges - a plain xpath falls short.
					NodeList nodes = (NodeList)xpath.evaluate(xpointer, n, XPathConstants.NODESET); 

					if (nodes.getLength() == 0) { //TODO: allow empty nodeset?...
						//TODO: is this a resource exception (should trigger fallback?)
						//FIXME: specific exception!
						throw new Exception("failed to resolve xpath "+xpointer+" in document "+href); 
					}
					
					Node p = e.getParentNode();
					Node at = e.getNextSibling();

					for (int i = 0; i < nodes.getLength(); i++) {
						Node imported = e.getOwnerDocument().importNode(nodes.item(i), true);
						p.insertBefore(imported, at);
					}

					p.removeChild(e);
				}
				else {
					xinclude(n); //NOTE: apply xi-processing recursively on loaded document!
					
					n = e.getOwnerDocument().importNode(n, true);
					e.getParentNode().replaceChild(n, e); //TODO: include leading/trailing PE/Comments/etc?!
				}
			}
			else {
				String text;
				if (encoding == null) {
					text = IOUtil.slurp(u);
				}
				else {
					// close the stream we opened ourselves; closing an already-closed stream is a no-op
					InputStream in = u.openStream();
					try {
						text = IOUtil.slurp(in, encoding);
					}
					finally {
						in.close();
					}
				}
				
				Node n = e.getOwnerDocument().createTextNode(text);
				e.getParentNode().replaceChild(n, e);
			}
		}
		catch (IOException ex) {
			if (fallback == null) throw ex;
			
			// move the fallback's children in place of the include element
			Node p = e.getParentNode();
			Node at = e.getNextSibling();
			Node n = fallback.getFirstChild();
			
			while (n != null) {
				Node next = n.getNextSibling();
				
				n.getParentNode().removeChild(n);
				p.insertBefore(n, at);
				n = next;
			}
			
			p.removeChild(e);
		}
	}
	
	/** Serializes a node to an indented XML string; the XML declaration names UTF-8. */
	public static String toString(Node node) throws TransformerException {
		StringWriter writer = new StringWriter();
		writeXml(node, writer, "UTF-8"); //or use UTF-16?... //TODO: terse, no indent...
		return writer.getBuffer().toString();
	}
	
	/**
	 * Serializes a node to a file in the given encoding. The file stream is closed
	 * even if serialization fails.
	 */
	public static void writeXml(Node node, File f, String encoding) throws IOException, TransformerException {
		OutputStream out = new FileOutputStream(f);
		try {
			writeXml(node, out, encoding);
		}
		finally {
			out.close();
		}
	}

	/**
	 * Serializes a node to a byte stream in the given encoding. The stream is flushed
	 * but deliberately not closed (callers may pass {@code System.out}).
	 *
	 * @throws TransformerException on serialization failure, or wrapping an IOException
	 *         (including an unsupported encoding)
	 */
	public static void writeXml(Node node, OutputStream out, String encoding) throws TransformerException {
		try {
			Writer wr = new OutputStreamWriter(out, encoding);
			writeXml(node, wr, encoding);
			wr.flush();
		} catch (IOException e) {
			throw new TransformerException(e);
		}
	}
	
	/**
	 * Serializes a node to a character stream with indentation enabled. {@code encoding}
	 * is written to the output's XML declaration; the caller's writer does the
	 * actual character encoding.
	 */
	public static void writeXml(Node node, Writer out, String encoding) throws TransformerException {
		DOMSource domSource = new DOMSource(node);
		StreamResult streamResult = new StreamResult(out);
		
		//TODO: optionally use shared serializer...
		Transformer serializer; 
		synchronized (poxTranformerFactory) {
			serializer = poxTranformerFactory.newTransformer();
		}
		
		//TODO: take properties from parameter
		serializer.setOutputProperty(OutputKeys.ENCODING, encoding);
		serializer.setOutputProperty(OutputKeys.INDENT, "yes");
		
		serializer.transform(domSource, streamResult); 
	}
	
	/** Command-line entry: loads the file named by args[0], applies XInclude, prints the result to stdout. */
	public static void main(String[] args) throws Exception {
		File f = new File(args[0]);
		Document doc = loadXML(f);
		xinclude(doc);
		writeXml(doc, System.out, "utf-8");
	}

	/**
	 * Escapes the five XML special characters ({@code & < > " '}) as entity
	 * references, making the result safe for element content and quoted attributes.
	 */
	public static String escape(String s) {
		// A single pass over the characters is used; a chain of replaceAll calls was
		// noticeably slower.
		int c = s.length();
		if (c == 0) return s;
		
		StringBuilder b = new StringBuilder(c * 2);
		
		for (int i = 0; i < c; i++) {
			char ch = s.charAt(i);
			
			switch (ch) {
				case '&': b.append("&amp;"); break;
				case '<': b.append("&lt;"); break;
				case '>': b.append("&gt;"); break;
				case '"': b.append("&quot;"); break;
				case '\'': b.append("&apos;"); break;
				default: b.append(ch);
			}
		}
		
		return b.toString();
	}
	
	/** Returns the names of all attributes of an element, in {@link NamedNodeMap} order. */
	public static List<String> getAttributeNames(Element element) {
		NamedNodeMap attribs = element.getAttributes();
		int c = attribs.getLength();
		List<String> names = new ArrayList<String>(c);

		for (int i = 0; i < c; i++) {
			Attr attr = (Attr)attribs.item(i);
			names.add(attr.getName());
		}
		
		return names;
	}
	
	/** Returns all attributes of an element as a mutable name-to-value map. */
	public static Map<String, String> getAttributes(Element element) {
		NamedNodeMap attribs = element.getAttributes();
		int c = attribs.getLength();
		
		Map<String, String> values = new HashMap<String, String>(c);

		for (int i = 0; i < c; i++) {
			Attr attr = (Attr)attribs.item(i);
			values.put(attr.getName(), attr.getValue());
		}
		
		return values;
	}
	
}
