//basic class for nodes to communicate with other nearby nodes

#include <pthread.h>
#include <assert.h>
#include <sys/socket.h>
#include <errno.h>
#include <sys/time.h>
#include "network_client.h"

static void *
receivethread(void *x)
{
  network_client *nc = (network_client *) x;
  nc->receive();
  return 0;
}

// Builds a client: opens the unicast and broadcast UDP sockets, derives this
// node's id from the local address of the connected broadcast socket, and
// finally starts the background receive thread.
//
// The receive thread is created last: receive() compares incoming traffic
// against client_id, so starting it before the fields below are assigned
// would let it read them uninitialised.
//
// NOTE(review): msg_id_lock and pending_lock are not initialised in this
// constructor -- confirm they are initialised where they are declared.
// NOTE(review): the thread handle is not stored, so the destructor has no way
// to stop or join the receive thread.
network_client::network_client() {

  msg_id = 0;

  //set up send socket (0 is a valid descriptor; only a negative value is an error)
  send_sockfd = socket(AF_INET, SOCK_DGRAM, 0);
  assert(send_sockfd >= 0);

  //set up broadcast socket
  broadcast_sockfd = socket(AF_INET, SOCK_DGRAM, 0);
  assert(broadcast_sockfd >= 0);
  int on = 1;
  int r = setsockopt(broadcast_sockfd, SOL_SOCKET, SO_BROADCAST, &on, sizeof(on));
  assert(r == 0);

  //connect the broadcast socket so basic_broadcast() can use plain send()
  struct sockaddr_in servaddr;
  bzero(&servaddr, sizeof(servaddr));
  servaddr.sin_family = AF_INET;
  servaddr.sin_port = htons(INCOMING_PORT);
  servaddr.sin_addr.s_addr = BROADCAST_ADDRESS;
  r = connect(broadcast_sockfd, (struct sockaddr *)&servaddr, sizeof(servaddr));
  assert(r == 0);

  //the local address chosen for the connected socket becomes our node id
  struct sockaddr_in host;
  socklen_t size = sizeof(host);
  r = getsockname(broadcast_sockfd, (struct sockaddr *)&host, &size);
  assert(r == 0);
  client_id = ntohl(host.sin_addr.s_addr);
  if (DEBUG)
    printf("DEBUG: (network client) client id: %u\n", client_id);

  //start the receive thread only once every field it reads is in place
  pthread_t th;
  r = pthread_create(&th, NULL, &receivethread, (void *) this);
  assert (r == 0);
  (void) r;  // keeps NDEBUG builds free of unused-value warnings
}

// Closes the unicast and broadcast socket descriptors, in that order.
network_client::~network_client() {
  const int owned_fds[] = { send_sockfd, broadcast_sockfd };
  for (unsigned int i = 0; i < sizeof(owned_fds) / sizeof(owned_fds[0]); i++)
    close(owned_fds[i]);
}

// Hands out the next message id.
// Returns the current value of msg_id and post-increments it while holding
// msg_id_lock.
//
// The pthread calls are made outside assert(): inside it they would be
// compiled away under NDEBUG and the counter would be updated unlocked.
unsigned int
network_client::get_msg_id() {
  int rc = pthread_mutex_lock(&msg_id_lock);
  assert(rc == 0);
  const unsigned int id = msg_id++;
  rc = pthread_mutex_unlock(&msg_id_lock);
  assert(rc == 0);
  (void) rc;  // silences the unused-value warning in NDEBUG builds
  return id;
}

// Returns a copy of the current peer list.
std::vector<node_id>
network_client::neighbors() {
  std::vector<node_id> snapshot(peer_list);
  return snapshot;
}

// Reports whether id appears in peer_list (linear scan, stops at first match).
bool
network_client::known_node(node_id id) {
  std::vector<node_id>::iterator cur = peer_list.begin();
  const std::vector<node_id>::iterator last = peer_list.end();
  bool found = false;
  while (!found && cur != last) {
    found = (*cur == id);
    ++cur;
  }
  return found;
}

// Serialises a message into buf: the raw bytes of *h first, then len bytes
// of msg directly after them.
// buf must have room for sizeof(message_header) + len bytes.
void
network_client::attach_header(network_protocol::message_header *h, char *msg, int len, char *buf) {
  const size_t header_bytes = sizeof(network_protocol::message_header);
  char *const payload_dst = buf + header_bytes;
  memmove(buf, h, header_bytes);
  memmove(payload_dst, msg, len);
}

// Sends len bytes of msg as one UDP datagram to node dest on INCOMING_PORT.
// dest is a host-order IPv4 address; it is converted to network order here.
// Returns whatever sendto() reports (bytes sent, or a negative value on error).
int 
network_client::basic_send(node_id dest, char *msg, int len) {
  if (DEBUG)
    printf("DEBUG: basic send to %u\n",dest);

  //describe the destination endpoint (ip/port)
  struct sockaddr_in target;
  memset(&target, 0, sizeof(target));
  target.sin_family = AF_INET;
  target.sin_addr.s_addr = htonl(dest);
  target.sin_port = htons(INCOMING_PORT);

  //hand the datagram to the unicast socket
  return ::sendto(send_sockfd, msg, len, 0,
                  (struct sockaddr*) &target, sizeof(target));
}

// Sends len bytes of msg over the broadcast socket, which the constructor has
// already connected to the broadcast address.
// Returns the result of send() (bytes sent, or a negative value on error).
int
network_client::basic_broadcast(char *msg, int len) {
  if (DEBUG)
    printf("DEBUG: basic broadcast\n");
  return ::send(broadcast_sockfd, msg, len, 0);
}


// Broadcasts a join request and waits up to ACK_TIMEOUT seconds for a reply.
//
// Returns 0 when a reply was recorded for this request,
// -(network_protocol::IOERROR) when the broadcast could not be sent, and
// -(network_protocol::NO_PEERS) when nobody answered in time.
//
// Locking: the entry's own lock is released before pending_lock is taken for
// the final erase, so no thread holds an entry lock while waiting for
// pending_lock on an entry that is visible in pending_msgs.
int 
network_client::join() {
  network_protocol::message_header h;
  h.source = client_id;
  h.dest = 0;
  h.id = get_msg_id();
  h.type = network_protocol::join;
  h.args = 0;

  //setup mechanisms to wait for reply
  //(pthread calls stay outside assert() so they survive NDEBUG builds)
  pending_msg p;
  p.replied = false;
  p.replier = 0;
  int rc = pthread_mutex_init(&(p.lock), 0);
  assert(rc == 0);
  rc = pthread_cond_init(&(p.wait), 0);
  assert(rc == 0);

  //hold the entry lock from before registration until the wait starts, so a
  //reply cannot be signalled before we are ready to observe it
  rc = pthread_mutex_lock(&(p.lock));
  assert(rc == 0);
  rc = pthread_mutex_lock(&pending_lock);
  assert(rc == 0);
  pending_msgs[h.id] = &p;
  rc = pthread_mutex_unlock(&pending_lock);
  assert(rc == 0);

  //send
  int r = basic_broadcast((char *) &h, sizeof(network_protocol::message_header));

  //wait for a reply (skipped when the broadcast itself failed)
  bool acked = false;
  if (r >= 0) {
    struct timeval now;
    struct timespec next_timeout;
    gettimeofday(&now, NULL);
    next_timeout.tv_sec = now.tv_sec + ACK_TIMEOUT;
    //keep the sub-second part so the wait is a full ACK_TIMEOUT seconds
    next_timeout.tv_nsec = now.tv_usec * 1000;

    //loop on the predicate: condition waits may wake up spuriously
    int l = 0;
    while (!p.replied && l == 0)
      l = pthread_cond_timedwait(&(p.wait), &(p.lock), &next_timeout);
    acked = p.replied;
  }

  rc = pthread_mutex_unlock(&(p.lock));
  assert(rc == 0);

  //unregister under pending_lock on every path, then release the primitives
  rc = pthread_mutex_lock(&pending_lock);
  assert(rc == 0);
  pending_msgs.erase(h.id);
  rc = pthread_mutex_unlock(&pending_lock);
  assert(rc == 0);

  pthread_cond_destroy(&(p.wait));
  pthread_mutex_destroy(&(p.lock));
  (void) rc;

  if (r < 0)
    return -(network_protocol::IOERROR);
  return acked ? 0 : -(network_protocol::NO_PEERS);
}

// Sends len bytes of msg to a known peer and waits up to ACK_TIMEOUT seconds
// for its acknowledgement.
//
// Returns the number of bytes handed to the socket (header included) once the
// destination has acknowledged, -(network_protocol::UNKNOWN_DEST) when dest is
// not in the peer list, and -(network_protocol::IOERROR) when len is negative,
// the datagram could not be sent, or no acknowledgement arrived in time.
//
// Locking: the entry's own lock is released before pending_lock is taken for
// the final erase, so no thread holds an entry lock while waiting for
// pending_lock on an entry that is visible in pending_msgs.
int 
network_client::send(node_id dest, char *msg, int len) {

  if (!known_node(dest))
    return -(network_protocol::UNKNOWN_DEST);

  //a negative length cannot describe a payload
  if (len < 0)
    return -(network_protocol::IOERROR);

  //the payload is not required to be NUL-terminated, so bound the print by len
  if (DEBUG) 
    printf("DEBUG: (default send) sending %.*s to %u\n", len, msg, dest);

  //heap buffer instead of a variable-length array (not standard C++)
  const int total = len + static_cast<int>(sizeof(network_protocol::message_header));
  std::vector<char> buf(total);

  network_protocol::message_header h;
  h.source = client_id;
  h.dest = dest;
  h.id = get_msg_id();
  h.type = network_protocol::send;
  h.args = 0;
  attach_header(&h, msg, len, &buf[0]);

  // set up stuff to wait for the reply
  //(pthread calls stay outside assert() so they survive NDEBUG builds)
  pending_msg p;
  p.replied = false;
  p.replier = dest;
  int rc = pthread_mutex_init(&(p.lock), 0);
  assert(rc == 0);
  rc = pthread_cond_init(&(p.wait), 0);
  assert(rc == 0);

  //hold the entry lock from before registration until the wait starts, so an
  //ack cannot be signalled before we are ready to observe it
  rc = pthread_mutex_lock(&(p.lock));
  assert(rc == 0);
  rc = pthread_mutex_lock(&pending_lock);
  assert(rc == 0);
  pending_msgs[h.id] = &p;
  rc = pthread_mutex_unlock(&pending_lock);
  assert(rc == 0);

  //send the message
  int r = basic_send(dest, &buf[0], total);

  //wait for the reply (skipped when the send itself failed)
  bool acked = false;
  if (r >= 0) {
    struct timeval now;
    struct timespec next_timeout;
    gettimeofday(&now, NULL);
    next_timeout.tv_sec = now.tv_sec + ACK_TIMEOUT;
    //keep the sub-second part so the wait is a full ACK_TIMEOUT seconds
    next_timeout.tv_nsec = now.tv_usec * 1000;
    if (DEBUG)
      printf("DEBUG: (default send) wait for %d seconds for ACK\n",ACK_TIMEOUT);

    //loop on the predicate: condition waits may wake up spuriously
    int l = 0;
    while (!p.replied && l == 0)
      l = pthread_cond_timedwait(&(p.wait), &(p.lock), &next_timeout);
    acked = p.replied;
  }

  rc = pthread_mutex_unlock(&(p.lock));
  assert(rc == 0);

  //unregister under pending_lock on every path, then release the primitives
  rc = pthread_mutex_lock(&pending_lock);
  assert(rc == 0);
  pending_msgs.erase(h.id);
  rc = pthread_mutex_unlock(&pending_lock);
  assert(rc == 0);

  pthread_cond_destroy(&(p.wait));
  pthread_mutex_destroy(&(p.lock));
  (void) rc;

  if (r >= 0 && acked)
    return r;
  return -(network_protocol::IOERROR);
}

// Broadcasts len bytes of msg with a broadcast-type header in front.
// The header's args field is set to hops - 1.
// Returns the result of basic_broadcast(), or -(network_protocol::IOERROR)
// when len is negative.
int
network_client::broadcast(int hops, char *msg, int len) {
  //a negative length cannot describe a payload
  if (len < 0)
    return -(network_protocol::IOERROR);

  //heap buffer instead of a variable-length array (not standard C++)
  const int total = len + static_cast<int>(sizeof(network_protocol::message_header));
  std::vector<char> buf(total);

  network_protocol::message_header h;
  h.source = client_id;
  h.dest = 0;
  h.id = get_msg_id();
  h.type = network_protocol::broadcast;
  h.args = hops-1;
  attach_header(&h, msg, len, &buf[0]);
  return basic_broadcast(&buf[0], total);
}

// Default handling for a received message whose header is *h and whose
// payload is the len bytes at msg.
//
// Only messages addressed to this node (h->dest == client_id) are acted on:
//   - send:              reply to h->source with a send_ack whose args field
//                        carries the id of the message being acknowledged.
//   - send_ack/join_ack: if a sender is waiting on the id in h->args, and it
//                        accepts any replier (replier == 0) or this sender,
//                        mark it replied and wake it up.
//
// Locking: pending_lock is held across the lookup and the wake-up so the
// waiting entry cannot be unregistered while it is being signalled; the lock
// order is pending_lock first, then the entry's own lock.
void
network_client::message_handler(node_id sender, network_protocol::message_header *h, 
                                char *msg, int len) {
  //the payload is not required to be NUL-terminated, so bound the print by len
  if (DEBUG)
    printf("DEBUG: (Default message handler) received from %u: %.*s\n", sender, len, msg);

  if (h->dest != client_id)
    return;

  switch (h->type) {
  case network_protocol::send: {
    //send ack
    if (DEBUG)
      printf("DEBUG: (Default message handler) send ack to %u\n", sender);
    network_protocol::message_header hr;
    hr.source = client_id;
    hr.dest = h->source;
    hr.id = get_msg_id();
    hr.type = network_protocol::send_ack;
    hr.args = h->id;
    basic_send(h->source, (char *)&hr, sizeof(network_protocol::message_header));
    break;
  }
  case network_protocol::send_ack:
  case network_protocol::join_ack: {
    //find the sender waiting for this ack, if any
    if (DEBUG)
      printf("DEBUG: (Default message handler) received ack from %u\n", sender);
    //pthread calls stay outside assert() so they survive NDEBUG builds
    int rc = pthread_mutex_lock(&pending_lock);
    assert(rc == 0);
    if (pending_msgs.count(h->args) > 0) {
      pending_msg *p = pending_msgs[h->args];
      rc = pthread_mutex_lock(&(p->lock));
      assert(rc == 0);
      if ((p->replier == 0) || (p->replier == sender)) {
        p->replied = true;
        p->replier = sender;
        rc = pthread_cond_signal(&(p->wait));
        assert(rc == 0);
      }
      rc = pthread_mutex_unlock(&(p->lock));
      assert(rc == 0);
    }
    rc = pthread_mutex_unlock(&pending_lock);
    assert(rc == 0);
    (void) rc;
    break;
  }
  default:
    break;
  }
}

// Receive loop, run on the background thread started by the constructor.
//
// Binds a UDP socket to INCOMING_PORT on all interfaces and then, for every
// datagram: drops it if it is shorter than a message header, came from this
// node, or repeats a (source, id) pair already seen; otherwise passes the
// header and payload to message_handler().
//
// The loop only ends when recvfrom() fails with something other than EINTR;
// the socket is closed before returning.
//
// NOTE(review): the per-source id sets stored in received_msgs are allocated
// here and never freed or pruned in this function.
void 
network_client::receive() {
  const int header_len = static_cast<int>(sizeof(network_protocol::message_header));
  struct sockaddr_in receiver_addr, sender_addr;

  //the maximum length of a UDP packet is ~64k. For safety, though, we constrain to ~32k
  char msg[32768];

  //bind to INCOMING_PORT on every local interface
  int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
  assert(sockfd >= 0);
  bzero(&receiver_addr,sizeof(receiver_addr));
  receiver_addr.sin_family = AF_INET;
  receiver_addr.sin_addr.s_addr = htonl(INADDR_ANY);
  receiver_addr.sin_port = htons(INCOMING_PORT);
  int r = bind(sockfd, (struct sockaddr *)&receiver_addr, sizeof(receiver_addr));
  assert(r >= 0);
  (void) r;

  while (1) {
    socklen_t len = sizeof(sender_addr);
    //one byte is held back so the payload can always be NUL-terminated
    int n = recvfrom(sockfd, msg, sizeof(msg)-1, 0,
                     (struct sockaddr *)&sender_addr, &len);
    if (n < 0) {
      if (errno == EINTR)
        continue;
      perror("network_client::receive: recvfrom");
      break;
    }

    //too short to contain a header: nothing meaningful to parse
    if (n < header_len)
      continue;

    //terminate the payload so it can safely be treated as a C string
    msg[n] = '\0';

    if (DEBUG)
      printf("DEBUG: Receive message from %s: %s\n", inet_ntoa(sender_addr.sin_addr), 
             &(msg[header_len]));

    unsigned int sender_ip = ntohl(sender_addr.sin_addr.s_addr);

    //if the message is from ourselves discard
    if (sender_ip == client_id)
      continue;

    network_protocol::message_header *h = (network_protocol::message_header*) msg;

    //if source is ourselves discard
    if (h->source == client_id)
      continue;

    //record that the message id was received; skip ids already seen
    if (received_msgs.count(h->source) == 0)
      received_msgs[h->source] = new std::set<int>();
    if (!received_msgs[h->source]->insert(h->id).second)
      continue;

    //strip off header: the handler gets the payload length only
    n -= header_len;

    //call message handler
    message_handler((node_id) sender_ip, h, &(msg[header_len]), n);
  }
  close(sockfd);
}

