package de.brightbyte.net;

import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import de.brightbyte.io.MultiStreamBuffer;
import de.brightbyte.io.MultiStreamBuffer.View;
import de.brightbyte.io.MultiStreamBuffer.ViewMonitor;

/**
 * Fans out a stream of byte chunks to any number of TCP clients.
 * <p>
 * Chunks handed to {@link #send(ByteBuffer)} (or written through {@link #getStream()}) are
 * copied into a shared {@link MultiStreamBuffer}; every connected client reads from its own
 * view of that buffer. The buffer is configured to allow skipping, so a slow client may miss
 * chunks; once a client's skipped-entry count exceeds {@link #getSkipKillThreshold()} its
 * connection is closed. If a {@link SkipHandler} is set, it is asked for a notice that is sent
 * to the client in place of the next data chunk after a skip.
 * <p>
 * All socket writes are performed by the thread executing {@link #run()} (normally the worker
 * started by {@link #start()}); other threads hand work to it via {@link #queueTask(Runnable)}.
 */
public class TcpDispatcher {
	
	/**
	 * Builds the notice sent to a client whose view skipped entries because it fell behind.
	 */
	public static interface SkipHandler {
		/**
		 * @param skippedEntryCount number of entries skipped since the client's counter was last reset
		 * @return the notice to send (sent from position 0 up to its limit), or {@code null} to send none
		 */
		public ByteBuffer makeSkipMessage(int skippedEntryCount);
	}
	
	/** Connected clients; guarded by this dispatcher's monitor. */
	protected Set<Target> targets = new HashSet<Target>();
	/** Selector driving all I/O; opened lazily by {@link #initSelector()}, nulled by {@link #close()}. */
	protected Selector selector ;
	/** Listening socket, set by {@link #bind(SocketAddress)}. */
	protected ServerSocketChannel serverSocket;
	/** Tasks to run on the selector thread before its next select(); guarded by this dispatcher's monitor. */
	protected List<Runnable> tasks = new ArrayList<Runnable>();
	
	/** A client whose skipped-entry count exceeds this value is disconnected. */
	protected int skipKillThreshold = 1024;
	/** Optional producer of skip notices; {@code null} means no notice is sent. */
	protected SkipHandler skipHandler = null;
	
	/** Shared buffer every client reads from through its own view. */
	protected MultiStreamBuffer<ByteBuffer> buffer;
	
	/**
	 * One connected client: its channel, its view of the shared buffer and its selection key.
	 * Write interest is requested when the view becomes non-empty and dropped when it runs empty.
	 */
	protected class Target {
		protected SocketChannel channel;
		protected MultiStreamBuffer.View<ByteBuffer> queue;
		protected SelectionKey key;
		protected int dropedChunks = 0;
		protected volatile boolean registering = false;
		
		/** Chunk currently being written; a non-blocking write may leave part of it. Selector thread only. */
		private ByteBuffer pending = null;
		
		/** Set by the first close() so re-entrant calls (e.g. via becameClosed) return immediately. */
		private boolean closing = false;
		
		private ViewMonitor<ByteBuffer> monitor = new ViewMonitor<ByteBuffer>() {
		
			public void skippedElement(View<ByteBuffer> view) {
				// a client that falls too far behind is disconnected
				if (view.getSkippedEntryCount() > skipKillThreshold) {
					try {
						//TODO: log
						close();
					} catch (IOException exx) {
						/* ignore: the connection is being dropped anyway */
					}
				}
			}
		
			public void removedElement(View<ByteBuffer> view) {
				// noop
			}
		
			public void becameEmpty(View<ByteBuffer> view) {
				unregister();
			}
		
			public void becameAvailable(View<ByteBuffer> view) {
				try {
					register();
				} catch (ClosedChannelException ex) {
					try {
						close();
					} catch (IOException exx) {
						/* ignore: the channel is already closed */
					}
				}
			}

			public void becameClosed(View<ByteBuffer> view) {
				try {
					close();
				} catch (IOException ex) {
					/* ignore */
				}
			}
		
		};
		
		public Target(SocketChannel ch) {
			// assign the channel first, in case setMonitor invokes the monitor right away
			channel = ch;
			queue = buffer.view(false);
			queue.setMonitor(monitor);
		}
		
		/**
		 * Requests write interest for this target's channel. The registration itself is queued as a
		 * task for the selector thread, so this method returns without waiting for it.
		 *
		 * @throws ClosedChannelException declared for callers; channel failures are handled by the task
		 */
		protected void register() throws ClosedChannelException {
			synchronized (this) {
				if (key != null && key.isValid()) return;
				if (registering) return;
				registering = true;
			}

			queueTask(new Runnable() {
				public void run() {
					boolean retry = false;
					try {
						Selector sel = selector;
						if (sel == null) return; // dispatcher has been closed
						
						SelectionKey k = channel.register(sel, SelectionKey.OP_WRITE, Target.this);
						synchronized (Target.this) {
							if (k.isValid()) key = k;
						}
					} catch (ClosedChannelException e) {
						/* channel is closed, nothing to register */
					} catch (CancelledKeyException e) {
						// The previous key was cancelled by unregister(), but the channel is only
						// deregistered during the next select(). Retry after that select; giving up
						// here would leave this target without write interest for good.
						retry = true;
					} finally {
						if (retry) queueTask(this);
						else registering = false;
					}
				}
			});
		}
		
		/** Drops write interest by cancelling the current selection key, if there is one. */
		protected void unregister() {
			synchronized (this) {
				if (key != null) {
					key.cancel();
					key = null;
				}
			}
		}
		
		/** @return true if this target's view currently holds no entries */
		public boolean isIdle() {
			return queue.isEmpty();
		}
		
		/** @return true if this target's view has skipped entries that were not yet reported */
		public boolean isSkipping() {
			return queue.getSkippedEntryCount()>0;
		}
		
		/**
		 * Writes to the channel: the rest of a partially written chunk if there is one, otherwise
		 * the next chunk (a skip notice or the next entry of the view). Called on the selector
		 * thread when the channel is writable.
		 *
		 * @throws IOException if writing fails; the caller treats the connection as dead
		 */
		public void sendChunk() throws IOException {
			if (pending == null || !pending.hasRemaining()) {
				pending = nextChunk();
				if (pending == null) {
					unregister(); //should not happen, but play safe
					return;
				}
			}
			
			channel.write(pending);
			
			if (pending.hasRemaining()) {
				// The socket buffer is full. If the view just ran empty, its key may have been
				// cancelled (becameEmpty -> unregister); re-request write interest so the rest
				// of this chunk still gets sent.
				register();
			}
		}
		
		/**
		 * Picks the next chunk to send: a skip notice if entries were skipped and the handler
		 * provides one, otherwise the next entry of the view.
		 *
		 * @return a buffer positioned at 0 whose position is private to this target, or null if there is nothing to send
		 */
		private ByteBuffer nextChunk() {
			int skipped = queue.getSkippedEntryCount();
			if (skipped > 0) {
				ByteBuffer notice = (skipHandler != null) ? skipHandler.makeSkipMessage(skipped) : null;
				queue.resetSkippedEntryCount();
				// copy: a handler may reuse a single buffer for every notice it builds
				if (notice != null) return copyFromStart(notice);
			}
			
			ByteBuffer data = queue.poll();
			if (data == null) return null;
			
			// entries are shared by all views; write through a duplicate so the position is ours alone
			ByteBuffer out = data.duplicate();
			out.rewind();
			return out;
		}

		/**
		 * Disconnects this target: drops write interest, removes it from the dispatcher and closes
		 * its view and channel. Calls after the first one return immediately.
		 *
		 * @throws IOException if closing the channel fails
		 */
		public void close() throws IOException {
			synchronized (this) {
				if (closing) return;
				closing = true;
			}
			
			unregister();
			
			synchronized (TcpDispatcher.this) {
				targets.remove(this);
			}
			
			try {
				queue.close();
			} finally {
				channel.close(); // close the socket even if closing the view fails
			}
		}

		/** @return true once this target's view has been closed */
		public boolean isClosed() {
			return queue.isClosed();
		}
	}
	
	/**
	 * @param bufferSize capacity passed to the shared {@link MultiStreamBuffer}
	 */
	public TcpDispatcher(int bufferSize) {
		buffer = new MultiStreamBuffer<ByteBuffer>(bufferSize);
		buffer.setAllowSkip(true);
	}
	
	/**
	 * Returns a private heap copy of {@code src}'s bytes from position 0 up to its limit.
	 * {@code src} itself is not modified.
	 *
	 * @return the copy, positioned at 0, or {@code null} if {@code src} is null
	 */
	private static ByteBuffer copyFromStart(ByteBuffer src) {
		if (src == null) return null;
		
		ByteBuffer view = src.duplicate();
		view.rewind();
		
		ByteBuffer copy = ByteBuffer.allocate(view.remaining());
		copy.put(view);
		copy.flip();
		return copy;
	}
		
	/**
	 * Debug helper: decodes the remaining bytes of {@code data} without changing its position or mark.
	 */
	protected static String toString(ByteBuffer data) {
		ByteBuffer view = data.duplicate(); // leaves the caller's position and mark untouched
		byte[] b = new byte[view.remaining()];
		view.get(b);
		
		return new String(b); //TODO: encoding; optionally hex-dump?
	}
	
	/**
	 * Opens a non-blocking server socket bound to {@code address} and registers it for accepts.
	 *
	 * @throws IOException if the socket cannot be opened, bound or registered; the socket is closed then
	 */
	public void bind(SocketAddress address) throws IOException {
		initSelector();
		
		ServerSocketChannel ch = ServerSocketChannel.open();
		try {
			ch.socket().bind(address);
			ch.configureBlocking(false);
			ch.register(selector, SelectionKey.OP_ACCEPT);
		} catch (IOException ex) {
			ch.close(); // don't leak the half-initialized socket
			throw ex;
		}
		
		serverSocket = ch;
	}
	
	/**
	 * Adds a connected socket as a target.
	 *
	 * @throws IllegalArgumentException if the socket has no channel (it was not created through a SocketChannel)
	 */
	public void addConnection(Socket sock) throws IOException {
		SocketChannel ch = sock.getChannel();
		if (ch == null) {
			throw new IllegalArgumentException("socket has no associated SocketChannel: " + sock);
		}
		addConnection(ch);
	}
	
	/** Opens the selector if it is not open yet. */
	protected synchronized void initSelector() throws IOException {
		if (selector == null) selector = Selector.open();
	}
	
	/**
	 * Adds a connected channel as a target and switches it to non-blocking mode. Write interest
	 * is requested later by the target's view monitor (becameAvailable).
	 */
	public synchronized void addConnection(SocketChannel ch) throws IOException {
		initSelector();
		
		ch.configureBlocking(false);
		Target t = new Target(ch);
		targets.add(t);
	}
	
	/**
	 * Queues {@code task} to run on the selector thread before its next select() and wakes the
	 * selector so the task is not delayed by a blocking select.
	 */
	protected synchronized void queueTask(Runnable task) {
		tasks.add(task);
		if (selector!=null) selector.wakeup();
	}
	
	/**
	 * Runs the dispatch loop on the calling thread until the selector is closed (see
	 * {@link #close()}) or the thread is interrupted. Each pass runs the queued tasks, selects,
	 * then accepts new connections and writes to writable targets.
	 *
	 * @throws IOException if selecting or accepting a connection fails
	 */
	public void run() throws IOException {
		initSelector();
		
		Selector sel;
		synchronized (this) {
			sel = selector;
		}
		if (sel == null) return; // closed concurrently
		
		try {
			// An interrupted thread makes select() return immediately every time, so stop instead of spinning.
			while (sel.isOpen() && !Thread.currentThread().isInterrupted()) {
				runQueuedTasks();
				
				sel.select();
				
				Iterator<SelectionKey> it = sel.selectedKeys().iterator();
				while (it.hasNext()) {
					SelectionKey key = it.next();
					// The selector never clears this set itself; a key left in it would be
					// handled again on every later pass.
					it.remove();
					
					try {
						handleKey(key);
					} catch (CancelledKeyException ex) {
						/* ignore, key gone */
					} catch (RuntimeException ex) { //don't kill main loop
						reportRuntimeException(ex, key);
					}
				}
			}
		} catch (ClosedSelectorException ex) {
			/* close() was called from another thread: leave the loop quietly */
		}
	}
	
	/**
	 * Runs every task queued so far. The snapshot and the clear happen under one lock, so tasks
	 * queued concurrently are kept for the next pass instead of being lost.
	 */
	private void runQueuedTasks() {
		Runnable[] batch;
		synchronized (this) {
			if (tasks.isEmpty()) return;
			batch = tasks.toArray(new Runnable[tasks.size()]);
			tasks.clear();
		}
		
		for (Runnable task : batch) {
			try {
				task.run();
			} catch (ClosedSelectorException ex) {
				throw ex; // the dispatcher is shutting down; let run() exit
			} catch (RuntimeException ex) { //don't kill main loop
				reportRuntimeException(ex, null);
			}
		}
	}
	
	/**
	 * Handles one selected key: accepts a pending connection, or sends the next chunk to a
	 * writable target. A target whose write fails is removed and closed.
	 *
	 * @throws IOException if accepting a connection fails
	 */
	private void handleKey(SelectionKey key) throws IOException {
		if (key.isAcceptable()) {
			SocketChannel ch = ((ServerSocketChannel)key.channel()).accept();
			if (ch!=null) addConnection(ch);
			return;
		}
		
		if (!key.isWritable()) return;
		
		Target tgt = (Target)key.attachment();
		try {
			tgt.sendChunk();
		} catch (IOException ex) {
			//channel is dead, we don't really care why. do some cleanup
			synchronized (this) {
				targets.remove(tgt);
			}
			
			try {
				tgt.close();
			} catch (IOException ignored) {
				/* the channel is already broken; closing it is best effort and must not stop the loop */
			}
		}
	}
	
	/**
	 * Reports a runtime exception caught by the dispatch loop, which then keeps running.
	 *
	 * @param key the key being handled, or {@code null} if the exception came from a queued task
	 */
	protected void reportRuntimeException(RuntimeException ex, SelectionKey key) {
		ex.printStackTrace(); //TODO: logger
	}

	/**
	 * Copies the remaining bytes of {@code data} into a new direct buffer and appends it to the
	 * shared buffer for delivery to every client. Advances {@code data}'s position to its limit.
	 */
	public void send(ByteBuffer data) {
		ByteBuffer buff = ByteBuffer.allocateDirect(data.remaining()); 
		buff.put(data);
		buff.flip();
		
		buffer.add(buff);
	}
	
	/**
	 * Shuts the dispatcher down: interrupts the worker, closes the selector and the server
	 * socket, and disconnects every client.
	 *
	 * @throws IOException if closing the selector or the server socket fails; clients are closed regardless
	 */
	public void close() throws IOException {
		List<Target> open;
		Selector sel;
		ServerSocketChannel server;
		
		synchronized (this) {
			if (worker!=null) worker.interrupt();
			running = false;
			
			open = new ArrayList<Target>(targets); // copy: Target.close() removes itself from targets
			sel = selector;
			server = serverSocket;
			
			serverSocket = null;
			selector = null;
		}
		
		try {
			if (sel!=null) sel.close();
		} finally {
			try {
				if (server!=null) server.close();
			} finally {
				// client sockets would otherwise stay open with nobody serving them
				for (Target t : open) {
					try {
						t.close();
					} catch (IOException ignored) {
						/* best effort during shutdown */
					}
				}
			}
		}
	}
	
	/**
	 * Returns a stream that collects written bytes into chunks of up to 1024 bytes and passes each
	 * chunk to {@link #send(ByteBuffer)} on flush (or when the next write would not fit). A single
	 * write larger than a chunk is sent as one chunk of its own. Closing the stream flushes it and
	 * closes this dispatcher.
	 */
	public OutputStream getStream() {
		return new OutputStream() {
			protected int chunkSize = 1024;
			protected ByteBuffer staging = ByteBuffer.allocateDirect(chunkSize);
			
			/** Flushes first if {@code length} more bytes would not fit into the staging buffer. */
			protected void prepare(int length) throws IOException {
				if (staging.position() + length > chunkSize) {
					flush();
				}
			}
			
			@Override
			public void write(int b) throws IOException {
				prepare(1);
				staging.put((byte)b);
			}
		
			@Override
			public void write(byte[] b, int off, int len) throws IOException {
				if (len>b.length-off) len = b.length-off;
				if (len<=0) return;
				
				if (len > chunkSize) {
					// would overflow the staging buffer even when empty: keep order, then send directly
					flush();
					send(ByteBuffer.wrap(b, off, len));
					return;
				}
				
				prepare(len);
				staging.put(b, off, len);
			}
			
			@Override
			public void write(byte[] b) throws IOException {
				write(b, 0, b.length);
			}
		
			@Override
			public void flush() throws IOException {
				if (staging.position()>0) {
					staging.flip();
					send(staging);
					staging.clear();
				}
			}
		
			@Override
			public void close() throws IOException {
				flush();
				TcpDispatcher.this.close();
			}
		
		};
	}

	public int getSkipKillThreshold() {
		return skipKillThreshold;
	}

	public void setSkipKillThreshold(int skipKillThreshold) {
		this.skipKillThreshold = skipKillThreshold;
	}
		
	/** True from {@link #start()} until the worker's loop ends or {@link #close()} is called. */
	private volatile boolean running = false; 
	protected Thread worker;
	
	public synchronized boolean isRunning() {
		return running;
	}
	
	public synchronized boolean isBound() {
		return serverSocket != null && serverSocket.socket().isBound() && !serverSocket.socket().isClosed();
	}
	
	/**
	 * Starts a worker thread executing {@link #run()}.
	 *
	 * @throws IllegalStateException if a worker is already running or the dispatcher is not bound
	 */
	public synchronized void start() {
		if (isRunning()) throw new IllegalStateException("already running");
		if (!isBound()) throw new IllegalStateException("not bound");
		
		worker = new Thread("TcpDispatcher#"+hashCode()) {
			public void run() {
				try {
					TcpDispatcher.this.run();
				} catch (IOException ex) {
					ex.printStackTrace(); //FIXME
				} finally {
					TcpDispatcher.this.running = false;
				}
			}
		};
		
		running = true;
		worker.start();
	}

	public SkipHandler getSkipHandler() {
		return skipHandler;
	}

	public void setSkipHandler(SkipHandler skipHandler) {
		this.skipHandler = skipHandler;
	}
	
	/**
	 * Prints the shared buffer size and the number of open, idle and skipping connections.
	 */
	public synchronized void dumpStats(PrintStream out) {
		int total = 0;
		int skipping = 0;
		int idle = 0;
		
		for (Target t: targets) {
			if (t.isClosed()) continue;
			
			total++;
			if (t.isSkipping()) skipping ++;
			if (t.isIdle()) idle ++;
		}
		
		out.println("Queue size: "+buffer.size());
		out.println("Connections: "+total+", "+idle+" idle, "+skipping+" skipping");
	}
	
	/**
	 * Demo: serves random hex lines on the port given as the first argument, printing statistics
	 * every few lines.
	 */
	public static void main(String[] args) throws IOException, InterruptedException {
		int port = Integer.parseInt(args[0]);
		SocketAddress a = new InetSocketAddress(port);
		
		int delay = 30;
		int size = 5;
		int skipKill = 20;
		int statsInterval = 10;
		
		TcpDispatcher dispatcher = new TcpDispatcher(size);
		dispatcher.setSkipKillThreshold(skipKill);
		dispatcher.setSkipHandler(new SkipHandler() {
			protected ByteBuffer buffer = ByteBuffer.allocateDirect(100); 
			public ByteBuffer makeSkipMessage(int skippedEntryCount) {
				String s = "*** skipped "+skippedEntryCount;
				
				buffer.clear();
				buffer.put(s.getBytes());
				buffer.flip();
				return buffer;
			}
		});
		
		dispatcher.bind(a);		
		dispatcher.start();
		
		PrintStream out = new PrintStream(dispatcher.getStream(), true);
		
		long i = 0;
		while (true) {
			if (delay>0) Thread.sleep((int)(Math.random()*delay+1));
			
			long d = (long)(Math.random()*Long.MAX_VALUE);
			String s = Long.toString(d, 16);
			
			System.out.println(s);
			out.println(s);
			
			i++;
			if ((i % statsInterval)==0) {
				dispatcher.dumpStats(System.out);
			}
		}
	}
}
