/*
 * dctpLprocessSApi.c -- Service-side API for Lprocess port of dccp-tp
 *
 * Copyright (C) 2008 Tom Phelan
 *
 * This file is part of dccp-tp.
 *
 * Dccp-tp is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 2.1 of the License, or
 * (at your option) any later version.
 *
 * Dccp-tp is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with dccp-tp.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Documentation and source code for dccp-tp is available at
 * http://www.phelan-4.com/dccp-tp/.
 */

#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <limits.h>
#include <dirent.h>
#include <pthread.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include "dctpCore.h"
#include "dctpSupport.h"
#include "dctpRawApi.h"
#include "dctpLprocessPApi.h"
#include "dctpLprocessApi.h"

/*
 * Command handlers, one per DCTPLP command.  Each is called from a
 * per-client command thread (cmdThread) with a packet whose appdata holds
 * the client's reply address (struct sockaddr_un) immediately followed by
 * the command structure.  Every handler sends a single reply to that
 * address.  Most free pkt before returning; socketConnect and socketSend
 * free it only when they reply with a NACK, since on success the packet may
 * have been given to dctpaConnect()/dctpaSend().
 */
static void socketOpen(dctpPacket *pkt);
static void socketClose(dctpPacket *pkt);
static void socketBind(dctpPacket *pkt);
static void socketListen(dctpPacket *pkt);
static void socketConnect(dctpPacket *pkt);
static void socketAccept(dctpPacket *pkt);
static void socketSend(dctpPacket *pkt);
static void socketRecv(dctpPacket *pkt);
static void socketSetSockOpt(dctpPacket *pkt);

/*
 * Dispatch table indexed directly by the command code (cmdThread calls
 * socketActions[cmd->cmd]).  The position of each entry therefore has to
 * match the numeric value of the corresponding DCTPLP command code.
 * NOTE(review): those codes are defined in a header not shown here --
 * confirm this order against it whenever a command is added or reordered.
 */
static void (*socketActions[DCTPLP_NUMCMDS])(dctpPacket *pkt) = {
    socketOpen,
    socketClose,
    socketBind,
    socketListen,
    socketConnect,
    socketAccept,
    socketSend,
    socketRecv,
    socketSetSockOpt
};

/* Descriptor for the command socket; -1 until openCmdSock() has bound it. */
static int cmdsd = -1;

/* Address the command socket is bound to (C99 designated initializers). */
static struct sockaddr_un caddr = {
    .sun_family = AF_UNIX,
    .sun_path   = DCTPLP_CMDSOCKNAME
};

static int openCmdSock(void) {
    DIR *d;
    struct dirent *de;
    char buf[NAME_MAX + 1];

    /* Get rid of old sockets in filesystem */
    if ((d = opendir(DCTPLP_SOCKDIRNAME)) != NULL) {
	while ((de = readdir(d)) != NULL) {
	    if (strcmp(de->d_name, ".") == 0) continue;
	    if (strcmp(de->d_name, "..") == 0) continue;
	    sprintf(buf, "%s/%s", DCTPLP_SOCKDIRNAME, de->d_name);
	    if (unlink(buf) < 0) {
		dctpoLog(DCTPLOG_ERR, "openCmdSock: can't delete %s: %s\n", buf, strerror(errno));
	    }
	}
	if (rmdir(DCTPLP_SOCKDIRNAME) < 0) {
	    dctpoLog(DCTPLOG_ERR, "openCmdSock: can't remove directory: %s\n", strerror(errno));
	}
    }
    /* Make a directory for sockets */
    if (mkdir(DCTPLP_SOCKDIRNAME, 0777) < 0) {
	dctpoLog(DCTPLOG_ERR, "openCmdSock: can't create socket directory: %s\n", strerror(errno));
    }
    /* Open the command socket */
    if ((cmdsd = socket(PF_UNIX, SOCK_DGRAM, 0)) < 0) {
	dctpoLog(DCTPLOG_ERR, "openCmdSock: can't open API command socket: %s\n", strerror(errno));
	return(-1);
    }
    if (bind(cmdsd, (struct sockaddr *)&caddr, sizeof(caddr)) < 0) {
	dctpoLog(DCTPLOG_ERR, "openCmdSock: can't bind API command socket: %s\n", strerror(errno));
	close(cmdsd);
	return(-1);
    }
    return(0);
}

static void writeRsp(void *rsp, int rlen, struct sockaddr_un *addr) {
    if (sendto(cmdsd, rsp, rlen, 0, (struct sockaddr *)addr,
	       sizeof(struct sockaddr_un)) < 0) {
	bpoint("writeRsp");
	dctpoLog(DCTPLOG_ERR, "writeRsp: error sending to %s: %s\n",
		 addr->sun_path, strerror(errno));
    }
}

/*
 * createThread -- Start func(ct) on a new detached thread.
 *
 * The thread id is stored in ct->thread.  Returns 0 on success.  On failure
 * the error is logged, errno is set to the code pthread_create() returned
 * (pthread_create reports errors through its return value, not errno), and
 * -1 is returned, so callers may use errno as the failure reason.
 */
static int createThread(void *(*func)(void *), dctpLpCmdThread *ct) {
    pthread_attr_t attr;
    int rc;

    pthread_attr_init(&attr);
    pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED);
    rc = pthread_create(&(ct->thread), &attr, func, ct);
    pthread_attr_destroy(&attr);
    if (rc != 0) {
	dctpoLog(DCTPLOG_ERR, "createThread: error creating API thread: %s\n",
		 strerror(rc));
	errno = rc;
	return(-1);
    }
    return(0);
}

/*
 * cmdThread -- Body of a per-client command thread (started by dispatchCmd).
 *
 * arg is the client's dctpLpCmdThread.  The thread waits on ct->cond for
 * packets queued on ct->first/ct->last, removes them in FIFO order, and runs
 * the matching socketActions[] handler with the queue mutex released.
 * Ownership of each packet passes to the handler.  A command code outside
 * the table (or with no handler) is answered with a NACK/EINVAL, and the
 * packet is freed here.  Never returns.
 */
void *cmdThread(void *arg) {
    dctpLpCmdThread *ct = (dctpLpCmdThread *)arg;
    dctpPacket *pkt;
    dctpLpGenCmd *cmd;
    struct sockaddr_un *raddr;
    dctpLpAckRsp rsp;

    pthread_mutex_lock(&(ct->mutex));
    while (1) {
	if (ct->first == NULL) {
	    /* Queue empty: sleep until dispatchCmd() appends a packet. */
	    pthread_cond_wait(&(ct->cond), &(ct->mutex));
	    continue;
	}
	/* Dequeue the oldest packet, then drop the lock while the handler runs
	 * so dispatchCmd() can keep queueing for this client. */
	pkt = ct->first;
	if ((ct->first = pkt->nextPkt) == NULL) {
	    ct->last = NULL;
	}
	pthread_mutex_unlock(&(ct->mutex));

	cmd = (dctpLpGenCmd *)(pkt->appdata + sizeof(struct sockaddr_un));
	/* The code comes from the client: reject negatives as well as values
	 * past the table before using it as an index. */
	if (cmd->cmd >= 0 && cmd->cmd < DCTPLP_NUMCMDS &&
	    socketActions[cmd->cmd] != NULL) {
	    socketActions[cmd->cmd](pkt);
	} else {
	    raddr = (struct sockaddr_un *)(pkt->appdata);
	    rsp.cmd = DCTPLP_NACKRSP;
	    rsp.rsperrno = EINVAL;
	    writeRsp(&rsp, sizeof(rsp), raddr);
	    dctpoPacketFree(pkt);   /* no handler took it */
	}
	pthread_mutex_lock(&(ct->mutex));
    }
}

/* Number of buckets in the per-client command-thread hash table. */
#define DCTPLP_CMDHASHSIZE  1019

/*
 * Command-thread records chained per bucket through ->next, bucketed by
 * client pid (see dispatchCmd).  Every bucket starts out empty.
 */
dctpLpCmdThread *cmdThreads[DCTPLP_CMDHASHSIZE] = { NULL };

/*
 * dispatchCmd -- Queue a received command on its client's command thread.
 *
 * Commands are serialized per client pid.  The first command from a pid
 * creates a dctpLpCmdThread (mutex, condition variable, packet FIFO and a
 * detached cmdThread) and links it into cmdThreads[].  Later commands from
 * the same pid are appended to that FIFO, and the thread is woken.
 * Ownership of pkt passes to the queue.  If the record cannot be allocated
 * or its thread cannot be started, a NACK carrying the errno is sent to the
 * client and pkt is freed here.
 *
 * cmd is currently unused; cmdThread re-reads the code from the packet.
 * Within this file only readCmd() calls this function, which is why
 * cmdThreads[] is walked and updated without a lock.
 */
void dispatchCmd(dctpPacket *pkt, pid_t pid, int cmd) {
    dctpLpCmdThread *ct;
    /* pid comes straight from the client datagram and may be negative;
     * reduce it as unsigned so the bucket is always < DCTPLP_CMDHASHSIZE. */
    uint_t hash = (uint_t)pid % DCTPLP_CMDHASHSIZE;
    dctpLpAckRsp rsp;
    int err;

    (void)cmd;
    for (ct = cmdThreads[hash]; ct != NULL; ct = ct->next) {
	if (ct->pid == pid) break;
    }
    if (ct == NULL) {
	if ((ct = malloc(sizeof(*ct))) == NULL) {
	    err = errno;
	    goto sendNack;
	}
	ct->pid = pid;
	pthread_mutex_init(&(ct->mutex), NULL);
	pthread_cond_init(&(ct->cond), NULL);
	ct->first = NULL;
	ct->last = NULL;
	if (createThread(cmdThread, ct) < 0) {
	    err = errno;   /* capture before cleanup can disturb it */
	    pthread_cond_destroy(&(ct->cond));
	    pthread_mutex_destroy(&(ct->mutex));
	    free(ct);
	    goto sendNack;
	}
	ct->next = cmdThreads[hash];
	cmdThreads[hash] = ct;
    }
    pkt->nextPkt = NULL;
    pthread_mutex_lock(&(ct->mutex));
    if (ct->last) {
	ct->last->nextPkt = pkt;
    } else {
	ct->first = pkt;
    }
    ct->last = pkt;
    pthread_cond_broadcast(&(ct->cond));
    pthread_mutex_unlock(&(ct->mutex));
    return;

 sendNack:
    rsp.cmd = DCTPLP_NACKRSP;
    rsp.rsperrno = err;
    writeRsp(&rsp, sizeof(rsp), (struct sockaddr_un *)(pkt->appdata));
    dctpoPacketFree(pkt);
}

/*
 * readCmd -- Body of the command-socket reader thread (arg is unused).
 *
 * Loops forever.  Each pass allocates a packet large enough for a client
 * address plus DCTPLP_MAXPKTSIZE bytes of command, and receives one datagram
 * from cmdsd into it: the sender's address goes at the start of appdata and
 * the command immediately after.  The packet is then handed to dispatchCmd().
 * On allocation or receive failure it logs, sleeps one second and retries.
 */
static void *readCmd(void *arg) {
    dctpPacket *pkt;
    int mlen;
    socklen_t flen;   /* recvfrom() takes a socklen_t *, not an int * */
    dctpLpGenCmd *cmd;

    (void)arg;
    /* Read commands and do them */
    while (1) {
	if ((pkt = dctpoPacketMalloc(DCTPLP_MAXPKTSIZE + sizeof(struct sockaddr_un))) == NULL) {
	    dctpoLog(DCTPLOG_DEBUG, "readCmd: can't get packet buffer\n");
	    sleep(1);
	    continue;
	}
	/* Zero the address slot so a sender that reports a short (e.g. unnamed)
	 * address leaves a defined sun_path instead of stale buffer bytes; the
	 * handlers use this slot as the reply address. */
	memset(pkt->appdata, 0, sizeof(struct sockaddr_un));
	flen = sizeof(struct sockaddr_un);
	cmd = (dctpLpGenCmd *)(pkt->appdata + sizeof(struct sockaddr_un));
	if ((mlen = recvfrom(cmdsd, cmd, DCTPLP_MAXPKTSIZE, 0,
			     (struct sockaddr *)(pkt->appdata), &flen)) < 0) {
	    dctpoLog(DCTPLOG_ERR, "readCmd: error reading from command socket: %s\n",
		     strerror(errno));
	    dctpoPacketFree(pkt);
	    sleep(1);
	    continue;
	}
	pkt->appdatalen = mlen + sizeof(struct sockaddr_un);
	/* NOTE(review): mlen is not checked against sizeof(dctpLpGenCmd), so a
	 * truncated datagram leaves cmd->pid/cmd->cmd as whatever the buffer held. */
	dispatchCmd(pkt, cmd->pid, cmd->cmd);
    }
}

/*
 * dctpoApiInit -- Start the service-side Lprocess API.
 *
 * Opens the command socket and then launches the detached readCmd thread.
 * If the command socket cannot be opened, no reader is started: it could
 * only loop forever on an invalid descriptor.  Errors are logged; there is
 * no return value.
 */
void dctpoApiInit(void) {
    pthread_t t;
    pthread_attr_t attr;
    int rc;

    if (openCmdSock() < 0) {
	dctpoLog(DCTPLOG_ERR, "dctpoApiInit: no command socket, API not started\n");
	return;
    }
    pthread_attr_init(&attr);
    pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED);
    rc = pthread_create(&t, &attr, readCmd, NULL);
    pthread_attr_destroy(&attr);
    if (rc != 0) {
	/* pthread_create() returns the error code; it does not set errno. */
	dctpoLog(DCTPLOG_ERR, "dctpoApiInit: can't start readCmd thread: %s\n",
		 strerror(rc));
    }
}

static void writeRspMsg(void *rsp, int rlen, void *buf, int blen, struct sockaddr_un *addr) {
    struct msghdr mhdr;
    struct iovec iov[2];

    iov[0].iov_base = rsp;
    iov[0].iov_len = rlen;
    iov[1].iov_base = buf;
    iov[1].iov_len = blen;
    mhdr.msg_name = addr;
    mhdr.msg_namelen = sizeof(struct sockaddr_un);
    mhdr.msg_iov = iov;
    mhdr.msg_iovlen = 2;
    mhdr.msg_control = NULL;
    mhdr.msg_controllen = 0;
    mhdr.msg_flags = 0;
    if (sendmsg(cmdsd, &mhdr, 0) < 0) {
	dctpoLog(DCTPLOG_ERR, "writeRspMsg: error sending to %s: %s\n",
		 addr->sun_path, strerror(errno));
    }
}

/* Exclusive upper bound on client-visible socket descriptors; 0 is never issued. */
#define DCTPLP_MAXSOCKETS  32767

static dctpSocket *sockptrs[DCTPLP_MAXSOCKETS] = { NULL };  /* descriptor -> core socket */
static int sdallocs[DCTPLP_MAXSOCKETS] = { 0 };             /* 1 while a descriptor is allocated */
static int nextSockdes = 1;                                 /* next free descriptor, -1 if none */

/*
 * badsd -- Validate a client-supplied socket descriptor.
 *
 * Returns 0 when sd is in 1..DCTPLP_MAXSOCKETS-1 and has a socket in
 * sockptrs[]; otherwise stores EBADF through cerrno and returns -1.
 */
static int badsd(int sd, int *cerrno) {
    int inRange = (sd > 0) && (sd < DCTPLP_MAXSOCKETS);

    if (inRange && sockptrs[sd] != NULL) {
	return(0);
    }
    *cerrno = EBADF;
    return(-1);
}

/*
 * badaddr -- Check that addr's family matches the socket's encapsulation.
 *
 * DCTPENCAP_RAW/NAT sockets need DCTP_AF_INET and DCTPENCAP_RAW6/NAT6
 * sockets need DCTP_AF_INET6; any other encapsulation is rejected without
 * looking at addr.  Returns 0 if acceptable, else stores EINVAL through
 * cerrno and returns -1.
 */
static int badaddr(dctpSocket *sock, dctpSockaddr *addr, int *cerrno) {
    int ok;

    switch (sock->encap) {
    case DCTPENCAP_RAW:
    case DCTPENCAP_NAT:
	ok = (addr->sa_family == DCTP_AF_INET);
	break;
    case DCTPENCAP_RAW6:
    case DCTPENCAP_NAT6:
	ok = (addr->sa_family == DCTP_AF_INET6);
	break;
    default:
	ok = 0;
	break;
    }
    if (!ok) {
	*cerrno = EINVAL;
	return(-1);
    }
    return(0);
}

/*
 * Serializes the descriptor allocator below.  getSocketDescriptor() and
 * freeSocketDescriptor() are reached from the per-client command threads
 * (dispatchCmd starts one per pid).  Without this lock, two clients opening
 * or closing sockets at the same time would race on nextSockdes/sdallocs
 * and could be handed the same descriptor.
 */
static pthread_mutex_t sdLock = PTHREAD_MUTEX_INITIALIZER;

/*
 * getSocketDescriptor -- Reserve a client-visible socket descriptor.
 *
 * Returns a descriptor in 1..DCTPLP_MAXSOCKETS-1, or -1 if all are in use.
 * The caller is expected to store the socket in sockptrs[sd] (or release
 * the descriptor with freeSocketDescriptor()).
 */
static int getSocketDescriptor(void) {
    int sd;

    pthread_mutex_lock(&sdLock);
    sd = nextSockdes;
    if (sd >= 0) {
	sdallocs[sd] = 1;
	/* Advance to the next free slot, skipping 0 on wrap-around; getting
	 * back to sd means every slot is taken. */
	do {
	    if (++nextSockdes >= DCTPLP_MAXSOCKETS) nextSockdes = 1;
	    if (nextSockdes == sd) {
		nextSockdes = -1;
		break;
	    }
	} while (sdallocs[nextSockdes] == 1);
    }
    pthread_mutex_unlock(&sdLock);
    return(sd);
}

/*
 * freeSocketDescriptor -- Release descriptor sd and clear its sockptrs slot.
 *
 * sd must be a descriptor previously returned by getSocketDescriptor().  If
 * the allocator was exhausted, sd becomes the next descriptor handed out.
 */
static void freeSocketDescriptor(int sd) {
    pthread_mutex_lock(&sdLock);
    sdallocs[sd] = 0;
    sockptrs[sd] = NULL;
    if (nextSockdes < 0) nextSockdes = sd;
    pthread_mutex_unlock(&sdLock);
}

/*
 * socketOpen -- DCTPLP socket command: create a new socket for the client.
 *
 * A client descriptor is reserved first.  The encapsulation is then chosen:
 * cmd->domain == DCTP_PF_INET selects DCTPENCAP_RAW/NAT, any other domain the
 * RAW6/NAT6 variants; cmd->type == DCTPLP_SOCKRAW selects the RAW form,
 * anything else NAT.  The ACK carries the new descriptor in rsp.sd.  A NACK
 * carries ENFILE if no descriptor was free, otherwise whatever error
 * dctpaSocket() reported.  pkt is always freed.
 */
static void socketOpen(dctpPacket *pkt) {
    struct sockaddr_un *raddr = (struct sockaddr_un *)(pkt->appdata);
    dctpLpSocketCmd *cmd = (dctpLpSocketCmd *)(raddr + 1);
    dctpLpSocketRsp rsp;
    dctpSocket *sock;
    int encap;
    int sd = getSocketDescriptor();

    if (sd < 0) {
	rsp.rsperrno = ENFILE;
	rsp.cmd = DCTPLP_NACKRSP;
    } else {
	int raw = (cmd->type == DCTPLP_SOCKRAW);

	if (cmd->domain == DCTP_PF_INET) {
	    encap = raw ? DCTPENCAP_RAW : DCTPENCAP_NAT;
	} else {
	    encap = raw ? DCTPENCAP_RAW6 : DCTPENCAP_NAT6;
	}
	sock = dctpaSocket(encap, cmd->scode, &(rsp.rsperrno));
	if (sock == NULL) {
	    freeSocketDescriptor(sd);
	    rsp.cmd = DCTPLP_NACKRSP;
	} else {
	    sockptrs[sd] = sock;
	    rsp.cmd = DCTPLP_ACKRSP;
	    rsp.sd = sd;
	}
    }
    writeRsp(&rsp, sizeof(rsp), raddr);
    dctpoPacketFree(pkt);
}

/*
 * socketClose -- DCTPLP close command: close socket cmd->sd.
 *
 * For a valid descriptor the core socket is closed, the descriptor is
 * released, and an ACK is sent.  Otherwise a NACK with EBADF (from badsd)
 * is sent.  pkt is always freed.
 */
static void socketClose(dctpPacket *pkt) {
    struct sockaddr_un *raddr = (struct sockaddr_un *)(pkt->appdata);
    dctpLpGenCmd *cmd = (dctpLpGenCmd *)(raddr + 1);
    dctpLpAckRsp rsp;
    int sd = cmd->sd;

    rsp.cmd = DCTPLP_NACKRSP;
    if (badsd(sd, &(rsp.rsperrno)) == 0) {
	dctpaClose(sockptrs[sd], NULL);
	freeSocketDescriptor(sd);
	rsp.cmd = DCTPLP_ACKRSP;
    }
    writeRsp(&rsp, sizeof(rsp), raddr);
    dctpoPacketFree(pkt);
}

/*
 * socketBind -- DCTPLP bind command: bind socket cmd->sd to cmd->addr.
 *
 * The descriptor and the address family are validated before dctpaBind()
 * is called; the first failing step supplies rsp.rsperrno for the NACK.
 * pkt is always freed.
 */
static void socketBind(dctpPacket *pkt) {
    struct sockaddr_un *raddr = (struct sockaddr_un *)(pkt->appdata);
    dctpLpBindCmd *cmd = (dctpLpBindCmd *)(raddr + 1);
    dctpLpAckRsp rsp;
    int ok;

    ok = (badsd(cmd->sd, &(rsp.rsperrno)) == 0)
	&& (badaddr(sockptrs[cmd->sd], &(cmd->addr), &(rsp.rsperrno)) == 0)
	&& (dctpaBind(sockptrs[cmd->sd], &(cmd->addr), &(rsp.rsperrno)) >= 0);
    rsp.cmd = ok ? DCTPLP_ACKRSP : DCTPLP_NACKRSP;
    writeRsp(&rsp, sizeof(rsp), raddr);
    dctpoPacketFree(pkt);
}

/*
 * socketListen -- DCTPLP listen command: dctpaListen() on socket cmd->sd
 * with cmd->backlog.
 *
 * ACK on success; NACK with the errno from badsd() or dctpaListen()
 * otherwise.  pkt is always freed.
 */
static void socketListen(dctpPacket *pkt) {
    struct sockaddr_un *raddr = (struct sockaddr_un *)(pkt->appdata);
    dctpLpListenCmd *cmd = (dctpLpListenCmd *)(raddr + 1);
    dctpLpAckRsp rsp;
    int failed;

    failed = badsd(cmd->sd, &(rsp.rsperrno))
	|| dctpaListen(sockptrs[cmd->sd], cmd->backlog, &(rsp.rsperrno)) < 0;
    rsp.cmd = failed ? DCTPLP_NACKRSP : DCTPLP_ACKRSP;
    writeRsp(&rsp, sizeof(rsp), raddr);
    dctpoPacketFree(pkt);
}

/*
 * socketConnect -- DCTPLP connect command: connect socket cmd->sd to cmd->addr.
 *
 * Any bytes following the dctpLpConnectCmd are passed to dctpaConnect() as
 * initial data (pkt, with appdata advanced past the headers).  If there are
 * none, NULL is passed instead.  The reply address is copied out of the
 * packet first, so the reply never reads a packet that dctpaConnect() may
 * have taken.  pkt is freed here unless dctpaConnect() accepted it as data,
 * i.e. on any NACK and on an ACK with no payload.
 */
static void socketConnect(dctpPacket *pkt) {
    uint8_t *base = pkt->appdata;
    struct sockaddr_un raddr = *((struct sockaddr_un *)base);
    dctpLpConnectCmd *cmd = (dctpLpConnectCmd *)(base + sizeof(struct sockaddr_un));
    dctpLpAckRsp rsp;
    dctpPacket *data;

    /* Leave only the trailing payload (if any) in pkt's appdata. */
    pkt->appdata = (uint8_t *)(cmd + 1);
    pkt->appdatalen -= pkt->appdata - base;
    data = (pkt->appdatalen > 0) ? pkt : NULL;

    if (badsd(cmd->sd, &(rsp.rsperrno))) {
	rsp.cmd = DCTPLP_NACKRSP;
    } else if (badaddr(sockptrs[cmd->sd], &(cmd->addr), &(rsp.rsperrno))) {
	rsp.cmd = DCTPLP_NACKRSP;
    } else if (dctpaConnect(sockptrs[cmd->sd], &(cmd->addr), data,
			    &(rsp.rsperrno)) < 0) {
	rsp.cmd = DCTPLP_NACKRSP;
    } else {
	rsp.cmd = DCTPLP_ACKRSP;
    }
    writeRsp(&rsp, sizeof(rsp), &raddr);
    /* Previously a successful connect without payload leaked pkt. */
    if (rsp.cmd == DCTPLP_NACKRSP || data == NULL) dctpoPacketFree(pkt);
}

/*
 * socketAccept -- DCTPLP accept command: accept a connection on socket cmd->sd.
 *
 * On success the new socket gets its own client descriptor, returned in
 * rsp.nsd, and the peer address is in rsp.addr.  If no descriptor is free,
 * the accepted socket is closed again and ENFILE is returned.  pkt is
 * always freed.
 */
static void socketAccept(dctpPacket *pkt) {
    struct sockaddr_un *raddr = (struct sockaddr_un *)(pkt->appdata);
    dctpLpGenCmd *cmd = (dctpLpGenCmd *)(raddr + 1);
    dctpLpAcceptRsp rsp;
    dctpSocket *nsock;

    if (badsd(cmd->sd, &(rsp.rsperrno))) {
	rsp.cmd = DCTPLP_NACKRSP;
    } else if ((nsock = dctpaAccept(sockptrs[cmd->sd], &(rsp.addr),
				    sizeof(dctpSockaddr), &(rsp.rsperrno))) == NULL) {
	/* dctpaAccept() returns a socket pointer, so failure is NULL (as with
	 * dctpaSocket() in socketOpen); an ordered "< 0" test never fires. */
	rsp.cmd = DCTPLP_NACKRSP;
    } else if ((rsp.nsd = getSocketDescriptor()) < 0) {
	dctpaClose(nsock, NULL);
	rsp.cmd = DCTPLP_NACKRSP;
	rsp.rsperrno = ENFILE;
    } else {
	sockptrs[rsp.nsd] = nsock;
	rsp.cmd = DCTPLP_ACKRSP;
    }
    writeRsp(&rsp, sizeof(rsp), raddr);
    dctpoPacketFree(pkt);
}

/*
 * socketSend -- DCTPLP send command: send the payload that follows the
 * dctpLpGenCmd header on socket cmd->sd via dctpaSend().
 *
 * The reply address is copied to a local first; on success the packet is
 * not freed here (presumably it now belongs to dctpaSend()).  pkt is freed
 * only when the reply is a NACK.
 */
static void socketSend(dctpPacket *pkt) {
    const size_t skip = sizeof(struct sockaddr_un) + sizeof(dctpLpGenCmd);
    uint8_t *base = pkt->appdata;
    struct sockaddr_un raddr = *((struct sockaddr_un *)base);
    dctpLpGenCmd *cmd = (dctpLpGenCmd *)(base + sizeof(struct sockaddr_un));
    dctpLpAckRsp rsp;
    int sent;

    /* Strip the reply address and command header; the payload remains. */
    pkt->appdata = base + skip;
    pkt->appdatalen -= skip;
    pkt->dccphdr = pkt->appdata;
    pkt->iphdr = pkt->appdata;

    /* Make sure input gets caught up */
    dctpoObjectLock(&readThreadLock);
    dctpoObjectUnlock(&readThreadLock);

    sent = (badsd(cmd->sd, &(rsp.rsperrno)) == 0)
	&& (dctpaSend(sockptrs[cmd->sd], pkt, &(rsp.rsperrno)) >= 0);
    rsp.cmd = sent ? DCTPLP_ACKRSP : DCTPLP_NACKRSP;
    writeRsp(&rsp, sizeof(rsp), &raddr);
    if (!sent) dctpoPacketFree(pkt);
}

/*
 * socketRecv -- DCTPLP recv command: fetch one packet from socket cmd->sd.
 *
 * On success the reply header (rsp.mlen = payload length) and the payload
 * go out together in one datagram via writeRspMsg(), and the received
 * packet is freed.  On failure a header-only NACK is sent.  The command
 * packet is always freed.
 */
static void socketRecv(dctpPacket *pkt) {
    struct sockaddr_un *raddr = (struct sockaddr_un *)(pkt->appdata);
    dctpLpGenCmd *cmd = (dctpLpGenCmd *)(raddr + 1);
    dctpLpRecvRsp rsp;
    dctpPacket *data;
    int got;

    got = (badsd(cmd->sd, &(rsp.rsperrno)) == 0)
	&& (dctpaRecv(sockptrs[cmd->sd], &data, &(rsp.rsperrno)) >= 0);
    if (got) {
	rsp.cmd = DCTPLP_ACKRSP;
	rsp.mlen = data->appdatalen;
	writeRspMsg(&rsp, sizeof(rsp), data->appdata, data->appdatalen, raddr);
    } else {
	rsp.cmd = DCTPLP_NACKRSP;
	writeRsp(&rsp, sizeof(rsp), raddr);
    }
    dctpoPacketFree(pkt);
    if (got) dctpoPacketFree(data);
}

/*
 * socketSetSockOpt -- DCTPLP setsockopt command: apply option cmd->optname
 * (with cmd->mandatory, cmd->optval/cmd->optlen) to socket cmd->sd.
 *
 * ACK on success; NACK with the errno from badsd() or dctpaSetSockopt()
 * otherwise.  pkt is always freed.
 */
static void socketSetSockOpt(dctpPacket *pkt) {
    struct sockaddr_un *raddr = (struct sockaddr_un *)(pkt->appdata);
    dctpLpSetOptCmd *cmd = (dctpLpSetOptCmd *)(raddr + 1);
    dctpLpAckRsp rsp;
    int failed;

    failed = badsd(cmd->sd, &(rsp.rsperrno))
	|| dctpaSetSockopt(sockptrs[cmd->sd], cmd->optname, cmd->mandatory,
			   cmd->optval, cmd->optlen, &(rsp.rsperrno)) < 0;
    rsp.cmd = failed ? DCTPLP_NACKRSP : DCTPLP_ACKRSP;
    writeRsp(&rsp, sizeof(rsp), raddr);
    dctpoPacketFree(pkt);
}

