/*
 * select.c
 * Copyright (c) 1995,1996,1997 Be, Inc.	All Rights Reserved 
 *
 * Handling select()ing on both the net and the tty simultaneously.  
 * Warning: This code is very tricky.
 *
 * One day, select() will just work on all types of descriptors and this
 * code won't be necessary.
 */
#include <stdio.h>
#include <string.h>
#include <OS.h>
#include <socket.h>
#include <fcntl.h>
#include <stdarg.h>
#include <signal.h>
#include <errno.h>
#include <netdebug.h>
#include <socket_private.h>

/*
 * NOTE(review): select and ioctl are presumably defined as macros by the
 * headers above; they are undefined so the calls in this file reach the
 * underlying functions -- confirm against socket.h / socket_private.h.
 */
#undef select
#undef ioctl

/* Maximum number of non-socket descriptors check_select() collects. */
#define NTTY 4
/* Polling interval, in microseconds, for tty_select()'s bounded wait. */
#define STEP 50000

/*
 * State shared between a select caller and one worker thread
 * (tty_select_thread, tty_mult_select_thread or net_select_thread).
 * The worker releases `running' when it starts.  When it finishes it fills
 * in `result' (and possibly `mask'), sets `ready', and releases `eventsem'
 * and then `donesem'.
 */
typedef struct select_parms {
	int fd;			/* descriptor to wait on; tty count for tty_mult_select_thread */
	int *fds;		/* tty descriptor array (tty_mult_select_thread only) */
	int extra;		/* caller's array index of this waiter (tty_select_multiple) */
	int mask;		/* in: 1 = read, 2 = write; out: readiness found */
	int *masks;		/* per-tty flag array (tty_mult_select_thread only) */
	int ready; 		/* set to 1 by the worker once `result' is valid */
	int result;		/* return value of the worker's select()/tty_select() */
	thread_id id;	/* worker thread (filled in by tty_select_multiple) */
	struct timeval *timeout;	/* timeout handed through to the worker */
	sem_id eventsem;	/* shared by all workers; each releases it on completion */
	sem_id running;	/* released by the worker as soon as it starts */
	sem_id donesem;	/* released by the worker as its last action */
} select_parms;

/* Spawn and start thread function f with argument a; the thread is named after f. */
#define fork_thread(f, a) _fork_thread(f, #f, a)

/* Zero timeout: makes tty_select() probe without waiting. */
static struct timeval zero = { 0, 0 };

/*
 * Spawn `func(args)' as a B_NORMAL_PRIORITY thread called `name' and start
 * it running.
 *
 * Returns the new thread's id, or the negative error code from
 * spawn_thread() when the thread could not be created (in which case
 * nothing is resumed).
 */
static long
_fork_thread(
			 thread_entry func,
			 char *name,
			 void *args
			 )
{
	thread_id id;

	dprintf("forking thread %s\n", name);
	id = spawn_thread(func, name, B_NORMAL_PRIORITY, args);
	if (id < B_NO_ERROR) {
		/* Don't hand an error code to resume_thread() as if it were an id. */
		dprintf("spawn_thread(%s) failed: %s\n", name, strerror(id));
		return (id);
	}
	resume_thread(id);
	return (id);
}


/*
 * SIGINT handler used by the tty worker threads.  It does nothing except
 * put itself back in place.  Its only job is to let a SIGINT interrupt a
 * blocking call inside the worker without killing the team.
 */
void
sigint(int ignored)
{
	void (*self)(int) = sigint;

	(void)ignored;
	signal(SIGINT, self);
}

/* Defined later in this file; declared here so the call below is prototyped. */
int tty_select(int ntty, int *ttyfdp, int *ttyflagsp, struct timeval *timeout);

/*
 * Worker thread: wait on the single tty parms->fd for the direction in
 * parms->mask (1 = read, 2 = write) with parms->timeout.
 *
 * tty_select() rewrites parms->mask with what it found.  The return value
 * is stored in parms->result, parms->ready is set, and eventsem then
 * donesem are released.  SIGINT is routed to sigint() so the waiting
 * thread can interrupt this worker with kill(id, SIGINT).
 */
static long
tty_select_thread(
				  void *arg
				  )
{
	select_parms *parms = (select_parms *)arg;
	int ret;

	signal(SIGINT, sigint);
	release_sem(parms->running);
	ret = tty_select(1, &parms->fd, &parms->mask, parms->timeout);
	dprintf("tty select returns %d, %d\n", ret, parms->mask);
	parms->result = ret;
	parms->ready = 1;
	release_sem(parms->eventsem);
	release_sem(parms->donesem);
	return (0);
}

/* Defined later in this file; declared here so the call below is prototyped. */
int tty_select(int ntty, int *ttyfdp, int *ttyflagsp, struct timeval *timeout);

/*
 * Worker thread: wait on several ttys at once.
 *
 * parms->fd holds the number of ttys, parms->fds their descriptors and
 * parms->masks their request flags (1 = read, 2 = write).  tty_select()
 * writes its results directly into parms->masks, which is the caller's
 * array.  The return value is stored in parms->result, parms->ready is
 * set, and eventsem then donesem are released.
 */
static long
tty_mult_select_thread(
					   void *arg
					   )
{
	select_parms *parms = (select_parms *)arg;
	int ret;

	signal(SIGINT, sigint);
	release_sem(parms->running);
	ret = tty_select(parms->fd, parms->fds, parms->masks, parms->timeout);
	dprintf("tty select (mult) returns %d\n", ret);
	parms->result = ret;
	parms->ready = 1;
	release_sem(parms->eventsem);
	release_sem(parms->donesem);
	return (0);
}

/*
 * Worker thread: select() on the single socket parms->fd for the
 * directions in parms->mask (1 = read, 2 = write), honouring
 * parms->timeout.
 *
 * On return parms->mask holds the directions found ready (0 unless
 * select() returned > 0), parms->result holds select()'s return value and
 * parms->ready is set; eventsem then donesem are released.  SIGINT is
 * ignored here because this thread is woken with _socket_interrupt()
 * rather than a signal.
 */
static long
net_select_thread(
				  void *arg
				  )
{
	select_parms *parms = (select_parms *)arg;
	struct fd_set rfds;
	struct fd_set wfds;
	int ret;

	signal(SIGINT, SIG_IGN);
	release_sem(parms->running);
	FD_ZERO(&rfds);
	if (parms->mask & 1) {
		FD_SET(parms->fd, &rfds);
	}
	FD_ZERO(&wfds);
	if (parms->mask & 2) {
		FD_SET(parms->fd, &wfds);
	}
	ret = select(parms->fd + 1, &rfds, &wfds, NULL, parms->timeout);
	dprintf("net select returns %d\n", ret);
	parms->mask = 0;
	if (ret > 0) {
		/*
		 * The sets describe readiness only after a successful select();
		 * on failure they still hold the request bits.
		 */
		if (FD_ISSET(parms->fd, &rfds)) {
			parms->mask |= 1;
		}
		if (FD_ISSET(parms->fd, &wfds)) {
			parms->mask |= 2;
		}
	}
	parms->result = ret;
	parms->ready = 1;
	release_sem(parms->eventsem);
	release_sem(parms->donesem);
	return (0);
}


int
tty_select_multiple(int ntty, int *ttyfd, int *ttyflags, 
					struct timeval *timeout)
{
	sem_id eventsem;
	select_parms parms[10];
	int nparms = 0;
	long status;
	int i;

	eventsem = create_sem(0, "select event");

	for (i = 0; i < ntty; i++) {
		if (ttyflags[i] & 1) {
			parms[nparms].fd = ttyfd[i];
			parms[nparms].extra = i;
			parms[nparms].mask = 1;
			parms[nparms].ready = 0;
			parms[nparms].timeout = timeout;
			parms[nparms].eventsem = eventsem;
			parms[nparms].donesem = create_sem(0, "ttyrw done");
			parms[nparms].running = create_sem(0, "ttyrw running");
			parms[nparms].id = fork_thread(tty_select_thread, 
										   &parms[nparms]);
			acquire_sem(parms[nparms].running);
			delete_sem(parms[nparms].running);
			nparms++;
		}
		if (ttyflags[i] & 2) {
			parms[nparms].fd = ttyfd[i];
			parms[nparms].extra = i;
			parms[nparms].mask = 2;
			parms[nparms].ready = 0;
			parms[nparms].timeout = timeout;
			parms[nparms].eventsem = eventsem;
			parms[nparms].donesem = create_sem(0, "ttyrw done");
			parms[nparms].running = create_sem(0, "ttyrw running");
			parms[nparms].id = fork_thread(tty_select_thread, 
										   &parms[nparms]);
			acquire_sem(parms[nparms].running);
			delete_sem(parms[nparms].running);
			nparms++;
		}
	}
	status = acquire_sem(eventsem);
	if (status < B_NO_ERROR) {
		dprintf("select net tty event: %s\n", strerror(status));
	}
	
	for (i = 0; i < nparms; i++) {
		if (!parms[i].ready) {
			dprintf("tty1 not ready: interrupt it\n");
			for (;;) {
				kill(parms[i].id, SIGINT);
				status = acquire_sem_etc(parms[i].donesem, 1, B_TIMEOUT, 
									 1000000);
				if (status != B_TIMED_OUT) {
					break;
				}
				dprintf("tty thread didn't die, killing again\n");	
			}
		}
		delete_sem(parms[i].donesem);
		wait_for_thread(parms[i].id, &status);
	}
	delete_sem(eventsem);
	for (i = 0; i < nparms; i++) {
		if (parms[i].ready && parms[i].result > 0) {
			memset(ttyflags, 0, sizeof(ttyflags[0]) * ntty);
			ttyflags[parms[i].extra] = parms[i].mask;
			return (1);
		}
	}
	return (-1);
}

/*
 * One attempt at waiting on a socket and one or more ttys together.
 * A net_select_thread and a tty worker are started, and the call waits
 * until either finishes.  It then interrupts the other and reaps both.
 *
 * net       socket to wait on; *netflags holds its requested directions
 *           (1 = read, 2 = write) and receives the directions found ready
 * ntty      number of ttys in tty[] / ttyflags[]
 * ttyflags  per-tty requested directions in, ready directions out
 *           (all cleared when no tty was ready)
 * timeout   passed through to both workers
 *
 * Returns 1 for a ready socket plus the tty worker's result when positive;
 * 0 when nothing was ready.
 */
int
select_net_tty_onepass(
					   int net,
					   int *netflags,
					   int ntty,
					   int *tty,
					   int *ttyflags,
					   struct timeval *timeout
					   )
{
	sem_id eventsem;
	select_parms ttyparms;
	select_parms netparms;
	thread_id netid;
	thread_id ttyid;
	long status;
	int count;
	int multi;

	eventsem = create_sem(0, "select event");
	netparms.fd = net;
	netparms.mask = *netflags;
	netparms.ready = 0;
	netparms.timeout = timeout;
	netparms.eventsem = eventsem;
	netparms.donesem = create_sem(0, "net done");
	netparms.running= create_sem(0, "net running");
	netid = fork_thread(net_select_thread, &netparms);

	/*
	 * Several ttys, or one tty wanted for both read and write, use the
	 * multi-tty worker.  That worker writes its per-tty results straight
	 * into ttyflags[] (through ttyparms.masks); ttyparms.mask is not used.
	 */
	multi = (ntty > 1 || *ttyflags == 3);
	if (multi) {
		ttyparms.fd = ntty;
		ttyparms.fds = tty;
		ttyparms.masks = ttyflags;
		ttyparms.ready = 0;
		ttyparms.timeout = timeout;
		ttyparms.eventsem = eventsem;
		ttyparms.donesem = create_sem(0, "tty done");
		ttyparms.running = create_sem(0, "tty running");
		ttyid = fork_thread(tty_mult_select_thread, &ttyparms);
	} else {
		ttyparms.fd = tty[0];
		ttyparms.mask = ttyflags[0];
		ttyparms.ready = 0;
		ttyparms.timeout = timeout;
		ttyparms.eventsem = eventsem;
		ttyparms.donesem = create_sem(0, "tty done");
		ttyparms.running = create_sem(0, "tty running");
		ttyid = fork_thread(tty_select_thread, &ttyparms);
	}

	acquire_sem(netparms.running);
	acquire_sem(ttyparms.running);
	delete_sem(netparms.running);
	delete_sem(ttyparms.running);

	status = acquire_sem(eventsem);
	if (status < B_NO_ERROR) {
		dprintf("select net tty event: %s\n", strerror(status));
	}

	if (!netparms.ready) {
		dprintf("net not ready: interrupt it\n");
		for (;;) {
			_socket_interrupt(net);
			status = acquire_sem_etc(netparms.donesem, 1, B_TIMEOUT, 
									 1000000);
			if (status != B_TIMED_OUT) {
				break;
			}
			dprintf("net thread didn't die, killing again\n");
		}
	}
	delete_sem(netparms.donesem);
	wait_for_thread(netid, &status);

	if (!ttyparms.ready) {
		dprintf("tty not ready: interrupt it\n");
		for (;;) {
			kill(ttyid, SIGINT);
			status = acquire_sem_etc(ttyparms.donesem, 1, B_TIMEOUT, 
									 1000000);
			if (status != B_TIMED_OUT) {
				break;
			}
			dprintf("tty thread didn't die, killing again\n");	
		}
	}
	delete_sem(ttyparms.donesem);
	wait_for_thread(ttyid, &status);
	
	delete_sem(eventsem);
	
	*netflags = 0;
	count = 0;
	if (netparms.ready && netparms.result > 0) {
		*netflags = netparms.mask;
		dprintf("added net event\n");
		count++;
	}
	if (ttyparms.ready && ttyparms.result > 0) {
		if (!multi) {
			*ttyflags = ttyparms.mask;
		}
		/* multi: ttyflags[] already holds the worker's per-tty results. */
		count += ttyparms.result;
		dprintf("added tty event\n");
	} else if (multi) {
		/* The worker may have left request flags behind; clear them all. */
		memset(ttyflags, 0, sizeof(ttyflags[0]) * ntty);
	} else {
		*ttyflags = 0;
	}
	return (count);
}

/*
 * Wait on a socket and ttys together by calling select_net_tty_onepass()
 * until it reports a non-negative count, then return that count.  The
 * net/tty flag arrays are updated in place by the one-pass routine.
 */
int
select_net_tty(int net, int *netflags, int ntty, int *tty, int *ttyflags,
			   struct timeval *timeout)
{
	int nready;

	dprintf("Starting...\n");
	for (;;) {
		nready = select_net_tty_onepass(net, netflags, ntty, tty, ttyflags,
										timeout);
		if (nready >= 0) {
			break;
		}
		dprintf("Trying again\n");
	}
	dprintf("...Finished (%d, %d, %d)\n", nready, *netflags, *ttyflags);
	return (nready);
}


/*
 * Poll or wait for readiness on one or more tty (non-socket) descriptors
 * using the tty driver's 'ichr' / 'ochr' ioctls.
 *
 * ntty       number of entries in ttyfdp[] / ttyflagsp[]
 * ttyfdp     tty descriptors
 * ttyflagsp  in: per-descriptor requested directions (1 = read, 2 = write);
 *            out: rewritten as described below
 * timeout    NULL: a single ioctl call with n = 1 (no time limit);
 *            {0, 0}: probe once without waiting (n = 0);
 *            otherwise: probe every STEP microseconds until it expires
 *
 * Several descriptors, or a single one with flags == 3, are handled in one
 * of two ways.  With a non-zero timeout, tty_select_multiple() is called
 * and its return value passed through.  With a zero timeout, each
 * requested direction is probed separately.  On the first hit, ttyflagsp[]
 * is cleared except for that entry and 1 is returned; otherwise all flags
 * are cleared and 0 is returned.
 *
 * A single descriptor returns 1 when ready, leaving *ttyflagsp unchanged.
 * It returns 0 when not ready or when the ioctl failed, clearing
 * *ttyflagsp.
 */
int
tty_select(
		   int ntty,
		   int *ttyfdp,
		   int *ttyflagsp,
		   struct timeval *timeout
		   )
{
	int n = 1;
	int howmany;
	int cmd;
	int tryflags;
	int ttyflags = *ttyflagsp;
	int ttyfd = *ttyfdp;
	int i;
	bigtime_t d;
	bigtime_t useconds = 0;

	/*
	 * n is both the ioctl input and output.  NOTE(review): a non-zero input
	 * appears to ask the driver to block until ready -- confirm against the
	 * tty driver.  A zero timeout therefore passes n = 0.
	 */
	if (timeout != NULL) {
		if (timeout->tv_sec == 0 && timeout->tv_usec == 0) {
			n = 0;
		} else {
			useconds = (timeout->tv_sec * 1000000LL + timeout->tv_usec);
		}
	}
	if (ntty > 1 || (ttyflagsp[0] == 3)) {
		if (n > 0) {
			return (tty_select_multiple(ntty, ttyfdp, ttyflagsp, timeout));
		} else {
			/* Zero timeout: probe each requested direction of each tty. */
			for (i = 0; i < ntty; i++) {
				if (ttyflagsp[i] & 1) {
					tryflags = 1;
					n = 0;
					n = tty_select(1, &ttyfdp[i], &tryflags, &zero);
					if (n > 0) {
						memset(ttyflagsp, 0, sizeof(ttyflagsp[0]) * ntty);
						ttyflagsp[i] = tryflags;
						return (n);
					}
				}
				if (ttyflagsp[i] & 2) {
					tryflags = 2;
					n = tty_select(1, &ttyfdp[i], &tryflags, &zero);
					if (n > 0) {
						memset(ttyflagsp, 0, sizeof(ttyflagsp[0]) * ntty);
						ttyflagsp[i] = tryflags;
						return (n);
					}
				}
			}
			memset(ttyflagsp, 0, sizeof(ttyflagsp[0]) * ntty);
		}
		return (0);
	}
	/*
	 * 'ichr' for a read request, 'ochr' otherwise.  NOTE(review): presumably
	 * they report characters available to read / room to write -- confirm.
	 */
	if (ttyflags == 1) {
		cmd = 'ichr';
	} else {
		cmd = 'ochr';
	}
	if (useconds) {
		/* Bounded wait: probe with n = 0 every STEP us until useconds pass. */
		for (d = 0; d < useconds; d += STEP) {
			n = 0;
			dprintf("ioctl(%d, %08x, %d)...\n", ttyfd, cmd, n);
			howmany = ioctl(ttyfd, cmd, &n);
			dprintf("ioctl(%d, %08x) = %d, %d\n", ttyfd, cmd, howmany, n);
			if (howmany >= 0 && n > 0) {
				break;
			}
			snooze(STEP);
		}
	} else {
		/* NULL timeout (n == 1) or zero timeout (n == 0): one ioctl call. */
		dprintf("ioctl(%d, %08x, %d)...\n", ttyfd, cmd, n);
		howmany = ioctl(ttyfd, cmd, &n);
		dprintf("ioctl(%d, %08x) = %d, %d\n", ttyfd, cmd, howmany, n);
	}
	if (howmany < 0) {
		if (useconds == 0 && cmd == 'ichr') {
			/*
			 * Check for EOF: probe again without waiting.  A second failure
			 * or available data is reported as "readable" (returns 1 with
			 * the flags unchanged).
			 */
			n = 0;
			howmany = ioctl(ttyfd, cmd, &n);
			if (howmany < 0 || n > 0) {
				dprintf("EOF or data\n");
				return (1);
			}
		}
		/*
		 * Else, assume EINTR.
		 */
		*ttyflagsp = 0;
		return (0);
	}
	if (n == 0) {
		dprintf("clearing flags\n");
		*ttyflagsp = 0;
	}
	dprintf("tty_select: %d\n", !!n);
	return (!!n);
}

/*
 * Debug trace of a select() call: label, width, the timeout pointer and its
 * fields (-1 when NULL), and word 0 of each fd_set (0xffffffff when NULL).
 * Arguments are cast so they match their conversion specifiers exactly.
 */
static void 
show(char *name,
	 int nbits,
	 struct fd_set *rbits,
	 struct fd_set *wbits,
	 struct fd_set *ebits,
	 struct timeval *timeout
	 )
{
	dprintf("%s select(%d): timeout = %p %ld %ld, masks = %08x %08x %08x\n",
			name,
			nbits,
			(void *)timeout,
			(long)(timeout ? timeout->tv_sec : -1),
			(long)(timeout ? timeout->tv_usec : -1),
			(unsigned int)(rbits ? rbits->mask[0] : 0xffffffff),
			(unsigned int)(wbits ? wbits->mask[0] : 0xffffffff),
			(unsigned int)(ebits ? ebits->mask[0] : 0xffffffff));
}

/*
 * OR the word-0 bit of every tty whose flags report readiness into the
 * result sets: flag 1 -> rbits, flag 2 -> wbits.  NULL sets are skipped.
 * The caller is responsible for clearing the words first.
 */
static void
check_select_set_tty_bits(struct fd_set *rbits, struct fd_set *wbits,
						  int ntty, const int *ttyfd, const int *flags)
{
	int i;

	for (i = 0; i < ntty; i++) {
		if (rbits && (flags[i] & 1)) {
			rbits->mask[0] |= (1U << ttyfd[i]);
		}
		if (wbits && (flags[i] & 2)) {
			wbits->mask[0] |= (1U << ttyfd[i]);
		}
	}
}

/*
 * select() replacement that can wait on a socket and on non-socket (tty)
 * descriptors at the same time.
 *
 * The requested descriptors are split into sockets (only the last socket
 * seen is tracked) and up to NTTY ttys.  The ttys are first probed
 * without waiting.  The call then dispatches to select_net_tty(), the
 * socket library's select(), or tty_select(), depending on what was
 * requested.  Results are written back into word 0 of rbits/wbits only,
 * so only descriptors below 32 can be reported; ebits is always cleared.
 *
 * Returns the ready count (0 on timeout), or -1 on error.  errno is
 * EINVAL when more than NTTY non-socket descriptors are requested.
 */
int
check_select(
			 int nbits,
			 struct fd_set *rbits,
			 struct fd_set *wbits,
			 struct fd_set *ebits,
			 struct timeval *timeout
			 )
{
	int netflags = 0;
	int ret;
	int ttyfd[NTTY];
	int ttyflags[NTTY];
	int ntty = 0;
	int netfd = -1;
	int i;

	show("Enter", nbits, rbits, wbits, ebits, timeout);
	if (ebits) {
		ebits->mask[0] = 0;
	}

	/*
	 * Classify each requested descriptor.  A tty requested for both read
	 * and write gets a single entry with flags 3, which tty_select() and
	 * select_net_tty_onepass() route to their multi-waiter paths.
	 */
	for (i = 0; i < nbits; i++) {
		unsigned int bit = 1U << (i % 32);
		int want = 0;

		if (rbits && (rbits->mask[i/32] & bit)) {
			want |= 1;
		}
		if (wbits && (wbits->mask[i/32] & bit)) {
			want |= 2;
		}
		if (want == 0) {
			continue;
		}
		if (is_socket(i)) {
			netfd = i;
			netflags |= want;
			continue;
		}
		if (ntty >= NTTY) {
			/* ttyfd[] / ttyflags[] are fixed-size: refuse, don't overflow. */
			dprintf("check_select: more than %d tty descriptors\n", NTTY);
			errno = EINVAL;
			return (-1);
		}
		ttyfd[ntty] = i;
		ttyflags[ntty++] = want;
	}
	dprintf("ntty %d, netfd %d\n", ntty, netfd);

	if (ntty > 0) {
		int copy[NTTY];

		/* Cheap non-waiting probe of the ttys before starting any threads. */
		memcpy(copy, ttyflags, sizeof(ttyflags[0]) * ntty);
		ret = tty_select(ntty, ttyfd, copy, &zero);
		if (ret > 0) {
			if (rbits) {
				rbits->mask[0] = 0;
			}
			if (wbits) {
				wbits->mask[0] = 0;
			}
			check_select_set_tty_bits(rbits, wbits, ntty, ttyfd, copy);
			show("Exit-ttypoll", ret, rbits, wbits, ebits, timeout);
			return (ret);
		}
	}

	if (netflags && ntty) {
		ret = select_net_tty(to_socket(netfd), 
							 &netflags, ntty, ttyfd, ttyflags, timeout);
		dprintf("select_net_tty returns\n");
		if (ret >= 0) {
			if (rbits) {
				rbits->mask[0] = 0;
				if (netflags & 1) {
					rbits->mask[0] |= (1U << netfd);
				}
				dprintf("setting rbits tty (%d)\n", ntty);
			}
			if (wbits) {
				wbits->mask[0] = 0;
				if (netflags & 2) {
					wbits->mask[0] |= (1U << netfd);
				}
				dprintf("setting wbits tty (%d)\n", ntty);
			}
			check_select_set_tty_bits(rbits, wbits, ntty, ttyfd, ttyflags);
		}
		show("Exit-all", ret, rbits, wbits, ebits, timeout);
		return (ret);
	}

	if (netflags) {
		int sock = to_socket(netfd);
		/* The socket id may be >= nbits; widen so select() examines it. */
		int width = (sock >= nbits) ? sock + 1 : nbits;

		/* Replace the descriptor's bit with the underlying socket's bit. */
		if (rbits && rbits->mask[0]) {
			rbits->mask[0] &= ~(1U << netfd);
			rbits->mask[0] |= (1U << sock);
		}
		if (wbits && wbits->mask[0]) {
			wbits->mask[0] &= ~(1U << netfd);
			wbits->mask[0] |= (1U << sock);
		}
		ret = select(width, rbits, wbits, NULL, timeout);
		if (ret >= 0) {
			/* Report readiness under the caller's descriptor number. */
			if (rbits && rbits->mask[0]) {
				rbits->mask[0] = (1U << netfd);
			}
			if (wbits && wbits->mask[0]) {
				wbits->mask[0] = (1U << netfd);
			}
		}
		show("Exit-net", ret, rbits, wbits, ebits, timeout);
		return (ret);
	}

	if (ntty > 0) {
		ret = tty_select(ntty, ttyfd, ttyflags, timeout);
		if (ret >= 0) {
			if (rbits) {
				rbits->mask[0] = 0;
			}
			if (wbits) {
				wbits->mask[0] = 0;
			}
			check_select_set_tty_bits(rbits, wbits, ntty, ttyfd, ttyflags);
		}
		show("Exit-tty", ret, rbits, wbits, ebits, timeout);
		return (ret);
	}
	show("Exit-nothing", 0, rbits, wbits, ebits, timeout);
	return (0);
}
