#include <stdio.h>
#include <iodef.h>
#include <brkdef.h>
#include <descrip.h>
#include <ctype.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include "global.h"
#include "nstype.h"
#include "qio.h"

/*
#define DEBUG
/**/
/* Values for the want_type argument of examoutput(): */
#define TYPE_LOOKUP 1   /* resolve a host: stop at the first usable A answer */
#define TYPE_SERVER 2   /* find a name server: also scan the authority and
                           additional sections and pick an NS address */

#define NS_PORT 53      /* DNS port; this file talks to the server over TCP */
#ifndef MAXBRK
#define MAXBRK 4096  /* Max number of characters we can read from socket at once;
                        also the size of the query/reply buffers below */
#endif

/* ------------------------------------------------------------------------ */
/* Change the host ip from string format to numerical */

/* Parse the dotted-decimal string `host` into the four bytes of int_addr.
 * Each octet is accumulated from its decimal digits and reduced modulo 256;
 * octets missing from a short string are stored as 0.  Characters are not
 * validated.  (Kept as a K&R definition so callers passing DOMAIN arrays
 * are not subjected to a new prototype.) */
static void ns_good_host(int_addr, host)
  unsigned char int_addr[4];
  char *host;
{
  int part = 0;

  while (part < 4) {
    unsigned char octet = 0;

    for (; *host != '\0' && *host != '.'; ++host)
      octet = (unsigned char)(octet * 10 + (*host - '0'));
    if (*host == '.')
      ++host;
    int_addr[part++] = octet;
  }
}

/* ------------------------------------------------------------------------ */
/* Initialize our connection to the name server */

/* Open a TCP connection to a name server.
 *   nameserver - dotted-decimal IPv4 address string of the server
 *   port       - port in host byte order (converted with tcp_htons)
 * Returns the channel from tcp_socket() on success, or 0 when the socket
 * cannot be created or the connect fails; the failure is reported via
 * say(), and on connect failure the socket is shut down first. */
static short start_client(nameserver, port)
  char *nameserver;
  unsigned short port;
{
  short chan;
  unsigned char hid[4];
  char msg[80];       /* NOTE(review): assumes get_socket_error writes < 80 bytes */
  struct sockaddr_in_type sin;

  ns_good_host(hid, nameserver);
  set_sin_tcp(&sin, (char *)hid, tcp_htons(port));

  /* NOTE(review): the third argument is PF_INET rather than a protocol
     number; kept as-is since tcp_socket is a project wrapper. */
  chan = tcp_socket(AF_INET, SOCK_STREAM, PF_INET);
  if (chan == -1) {
    get_socket_error(msg);
    (void)say("nsl socket: %s (%s)\n", nameserver, msg);
    return 0;
  }

  if (tcp_connect(chan, &sin, sizeof(sin)) == -1) {
    get_socket_error(msg);
    (void)say("nsl connect: %s (%s)\n", nameserver, msg);
    tcp_shutdown(chan);
    return 0;
  }
  return chan;
}

/* ------------------------------------------------------------------------ */
/* Decode the (possibly compressed) domain name that starts at offset *ind
 * of the DNS message `top` into dotted text in `name`.
 *
 *   top  - start of the DNS message; compression-pointer offsets are
 *          relative to it (RFC 1035 4.1.4)
 *   ind  - in: offset of the name; out: offset of the first byte after the
 *          name as stored at *ind, i.e. after its terminating zero label, or
 *          just after the first 2-byte compression pointer
 *   name - output; at most 255 characters plus a NUL are written, so it
 *          must hold at least 256 bytes
 *
 * The root name (a lone zero byte) decodes to "".  At most 64 compression
 * pointers are followed, so a pointer loop in a malformed reply cannot run
 * forever.
 * NOTE(review): reads from `top` are not bounded by the received length;
 * a truncated or hostile packet can still make this read past the data.
 */
void print_domain(unsigned char *top, int *ind, char *name)
{
  enum { DNS_TEXT_MAX = 255, DNS_MAX_POINTER_HOPS = 64 };
  int walk = *ind;      /* current read offset in top */
  int resume = -1;      /* offset just past the first pointer, once seen */
  int curr = 0;         /* current write offset in name */
  int hops = 0;
  int len;

  while (top[walk] != 0) {
    if ((top[walk] & 0xC0) == 0xC0) {
      /* Two-byte pointer: the top two bits are set, the low 14 bits are
         the offset where the rest of the name lives. */
      if (resume < 0) resume = walk + 2;
      if (++hops > DNS_MAX_POINTER_HOPS) break;
      walk = ((top[walk] & 0x3F) << 8) | top[walk + 1];
      continue;
    }
    /* Ordinary label: length byte followed by that many characters.
       Always consume the bytes so the offsets stay right, but never write
       past DNS_TEXT_MAX characters. */
    for (len = top[walk++]; len > 0; len--, walk++)
      if (curr < DNS_TEXT_MAX) name[curr++] = (char)top[walk];
    if (top[walk] != 0 && curr < DNS_TEXT_MAX) name[curr++] = '.';
  }
  name[curr] = '\0';
  *ind = (resume >= 0) ? resume : walk + 1;
}

/* ------------------------------------------------------------------------ */
/* One domain name collected while parsing a reply, together with the
 * address that namelist_set() may later attach to it. */
typedef struct name_s_entry name_entry;
struct name_s_entry {
  char *domain;              /* heap copy of the name; freed by namelist_destroy */
  int type;                  /* record type that introduced the name (T_NS, T_MX, ...) */
  struct in_addr domain_ip;  /* set by namelist_set; zero-filled by namelist_add */
  name_entry *next;          /* next entry; namelist_add appends at the tail */
};

/* Head of a per-reply list of name_entry records (see namelist_create). */
typedef struct name_s_list name_list;
struct name_s_list {
  int current_count, current_set;  /* entries added / successful namelist_set calls */
  name_entry *list;                /* first entry, or NULL when empty */
};
/* Prototype for namelist_add, which is defined further below. */
static int namelist_add(name_list *nl, char *host, int type);

/* ------------------------------------------------------------------------ */
/* Skip one entry of the question section (RFC 1035 4.1.2) at offset *index
 * of the DNS message `buffer`, advancing *index past QNAME, QTYPE and
 * QCLASS.  The decoded values are only printed in DEBUG builds.
 * `numchar` and `nl` are accepted to match the caller but are not used. */
static void question(buffer, index, numchar, nl)
  char *buffer;
  int *index, numchar;
  name_list *nl;
{
  unsigned short class, type;
  /* A legal DNS name can be up to 253 characters of text, so the old
     100-byte buffer could be overflowed by a valid reply. */
  char output[256];

  (void)numchar;
  (void)nl;

/* What did we ask the question about */
  print_domain((unsigned char *)buffer, index, output);
  (void)memcpy(&type, &buffer[*index], sizeof(type));
  *index += sizeof(type);
  (void)memcpy(&class, &buffer[*index], sizeof(class));
  *index += sizeof(class);
#ifdef DEBUG
  (void)printf("Question Class: %d Type: %d [%s]\n", tcp_htons(class),
            tcp_htons(type), output);
#endif
}

/* ------------------------------------------------------------------------ */
/* Allocate an empty name list (no entries, both counters zero).
 * Returns NULL if the allocation fails; release with namelist_destroy(). */
static name_list *namelist_create(void)
{
  name_list *nl = (name_list *)malloc(sizeof *nl);

  if (nl == NULL)
    return NULL;
  nl->current_count = 0;
  nl->current_set = 0;
  nl->list = NULL;
  return nl;
}

/* ------------------------------------------------------------------------ */
/* Append `domain` (with its record `type`) to `nlist` unless an entry with
 * the same name already exists.
 * Returns 1 when the name is present afterwards (added or already there),
 * 0 when nlist/domain is NULL or memory runs out; on failure nothing is
 * added, so the list never contains an entry with a NULL domain. */
static int namelist_add(name_list *nlist, char *domain, int type)
{
  name_entry **tail;
  name_entry *entry;

  /* The old code used a bare "return;" here, which gives an int function
     no return value. */
  if (!nlist || !domain) return 0;

  /* One pass both looks for a duplicate and finds the tail link. */
  for (tail = &nlist->list; *tail; tail = &(*tail)->next)
    if ((*tail)->domain && strcmp((*tail)->domain, domain) == 0)
      return 1;

  entry = calloc(1, sizeof *entry);      /* zeroes domain_ip as before */
  if (!entry) return 0;
  entry->domain = malloc(strlen(domain) + 1);
  if (!entry->domain) {
    /* Previously the entry was kept with a NULL domain, which later
       crashed strcmp() in lookups. */
    free(entry);
    return 0;
  }
  strcpy(entry->domain, domain);
  entry->type = type;
  entry->next = NULL;

  *tail = entry;
  nlist->current_count++;
  return 1;
}

/* ------------------------------------------------------------------------ */
/* Attach address `ip` to the entry of `nlist` whose domain equals `name`
 * and bump nlist->current_set.  Does nothing when nlist or name is NULL;
 * reports through say() when no entry matches. */
static void namelist_set(name_list *nlist, char *name, struct in_addr ip)
{
  name_entry *walk;

  if (!nlist || !name) return;

  for (walk = nlist->list; walk; walk = walk->next) {
    /* Skip entries without a name instead of passing NULL to strcmp. */
    if (walk->domain && strcmp(walk->domain, name) == 0) {
      walk->domain_ip = ip;
      nlist->current_set++;
      return;
    }
  }
  say("*** Fatal error - could not find a hostname for ip in namelist_set");
}

/* ------------------------------------------------------------------------ */
/* Free every entry of `nlist` (its domain string and the node itself),
 * then the list head.  A NULL list is ignored. */
static void namelist_destroy(name_list *nlist)
{
  name_entry *node, *next;

  if (nlist == NULL)
    return;
  for (node = nlist->list; node != NULL; node = next) {
    next = node->next;
    free(node->domain);
    free(node);
  }
  free(nlist);
}

/* ------------------------------------------------------------------------ */
/* Parse one resource record (RFC 1035 4.1.3) at offset *my_index of the
 * DNS message `buffer`.
 *
 *   buffer   - start of the DNS message (after the TCP length prefix)
 *   host     - receives the target name of NS and CNAME records; must hold
 *              at least 256 bytes
 *   realip   - receives the 4 address bytes of an A record (its first
 *              MAXIP bytes are cleared first)
 *   my_index - in: offset of the record; out: offset just past its RDATA,
 *              computed from RDLENGTH so unparsed types are skipped intact
 *   nl       - NS/MX/MF/MB targets are added to it; when isq is 0, A-record
 *              addresses are attached to the matching name
 *   isq      - nonzero suppresses the namelist_set for A records (the caller
 *              passes 1 while scanning the answer section)
 *
 * Returns TRUE only for an A record with a 4-byte address, FALSE otherwise.
 */
static int rrecord(buffer, host, realip, my_index, nl, isq)
  char *buffer, *host;
  DOMAIN realip;
  int *my_index;
  name_list *nl;
  int isq;
{
  unsigned short type, rdlength, tshort;
  char output[256];
  int loop, ind, end_ind, good = FALSE;

  ind = *my_index;
  print_domain((unsigned char *)buffer, &ind, output);   /* owner name */

  (void)memcpy(&type, &buffer[ind], sizeof(type));
  ind += sizeof(type);
  type = tcp_htons(type);

  /* Skip CLASS (2 bytes) and TTL (4 bytes).  These are fixed wire sizes,
     so they must not depend on sizeof of host integer types. */
  ind += 2 + 4;

  (void)memcpy(&rdlength, &buffer[ind], sizeof(rdlength));
  ind += sizeof(rdlength);
  rdlength = tcp_htons(rdlength);

  end_ind = ind + rdlength;

/* We have to do different things depending upon the type of response */

  switch ((int)type)
  {
    case T_A:
      if (rdlength != 4) break;      /* not an IPv4 address; ignore it */
      for (loop=0; loop<MAXIP; loop++) realip[loop] = 0;
      for (loop=0; loop<4; loop++) realip[loop] = buffer[ind++];
      if (!isq) {
        struct in_addr add;
        memcpy(&add, realip, sizeof(add));
        namelist_set(nl, output, add);
      }
      good = TRUE;
#ifdef DEBUG
      printf("Setting real ip: %d %d %d %d\n", realip[0], realip[1],
          realip[2], realip[3]);
#endif
      break;
    case T_NS:
      print_domain((unsigned char *)buffer, &ind, output);
      namelist_add(nl, output, T_NS);
      strcpy(host, output);
#ifdef DEBUG
      printf("T_NS Canocial name: [%s]\n", output);
#endif
      break;
    case T_CNAME:
      print_domain((unsigned char *)buffer, &ind, output);
      (void)strcpy(host, output);
#ifdef DEBUG
      printf("T_CNAME Canocial name: [%s]\n", output);
#endif
      break;
    case T_SOA:
#ifdef DEBUG
      {
        unsigned serial;

        print_domain((unsigned char *)buffer, &ind, output);
        printf("Source of SOA: %s\n", output);
        print_domain((unsigned char *)buffer, &ind, output);
        printf("Mailbox of SOA: %s\n", output);
        memcpy(&serial, &buffer[ind], sizeof(serial));
        ind += 4;
        printf("serial: %u\n", serial);
        memcpy(&serial, &buffer[ind], sizeof(serial));
        ind += 4;
        printf("refresh: %u\n", serial);
        memcpy(&serial, &buffer[ind], sizeof(serial));
        ind += 4;
        printf("retry refresh: %u\n", serial);
        memcpy(&serial, &buffer[ind], sizeof(serial));
        ind += 4;
        printf("expire: %u\n", serial);
      }
#endif
      break;
    case T_WKS:
#ifdef DEBUG
      {
        DOMAIN tdomain;

        memcpy(tdomain, &buffer[ind], sizeof(tdomain));
        ind += sizeof(tdomain);
        printf("WKS record for %u.%u.%u.%u available\n", tdomain[0], tdomain[1],
               tdomain[2], tdomain[3]);
      }
#endif
      break;
    case T_HINFO:
#ifdef DEBUG
      {
        unsigned char tbyte;
        char CPU[256], OS[256];

        tbyte = buffer[ind++];
        memcpy(CPU, &buffer[ind], tbyte);
        CPU[tbyte] = 0;
        ind += tbyte;
        tbyte = buffer[ind++];
        memcpy(OS, &buffer[ind], tbyte);
        OS[tbyte] = 0;
        ind += tbyte;
        printf("CPU/OS = %s/%s\n", CPU, OS);
      }
#endif
      break;
    case T_MX:
      memcpy(&tshort, &buffer[ind], sizeof(tshort));   /* PREFERENCE */
      tshort = tcp_htons(tshort);
      ind += sizeof(tshort);
      print_domain((unsigned char *)buffer, &ind, output);
      namelist_add(nl, output, T_MX);
#ifdef DEBUG
      printf("MX = Pref: %d for %s\n", tshort, output);
#endif
      break;
    case T_MF:
      print_domain((unsigned char *)buffer, &ind, output);
      namelist_add(nl, output, T_MF);
      /* This used to fall through into T_MB, decoding a second "name" from
         past the end of this record's RDATA. */
      break;
    case T_MB:
      print_domain((unsigned char *)buffer, &ind, output);
      namelist_add(nl, output, T_MB);
      break;
    case T_MINFO:
    default:
#ifdef DEBUG
      say("Processing type: %d\n", type);
#endif
      break;
  }
  *my_index = end_ind;
  return good;
}

/* ------------------------------------------------------------------------ */
static unsigned short examoutput(chan, buffer, host, realip, numchar, server,
                                 want_type)
  short chan;
  char *buffer;                  /* The data from the nameserver */
  char *host;                    /* Host name that we are dealing with */
  DOMAIN realip;                 /* Hopefully the IP of that host name */
  int numchar;
  char *server;
  int want_type;
{
  int loop;                      /* Loop for number of answers/questons */
  unsigned short qdcount;        /* Returned number of questions */
  unsigned short ancount;        /* Returned number of answers */
  unsigned short nscount;        /* Returned number of name server responses */
  unsigned short arcount;        /* Returned number of additional records */
  unsigned short size;           /* Used to calculate size of POP instr. */
  struct HEADER *packet;
  int index, good = 0;
  char *top_of_record = NULL;
  char some_host[256];  /* 255 max length according to the rfc1035 */
  struct iosb_type iosb;
  short rsize;
  name_list *current_nl = NULL;

  memcpy(&rsize, buffer, 2);
  rsize = tcp_htons(rsize);

  while (numchar<rsize)
  {
    index = tcp_receive(chan, &buffer[numchar], rsize-numchar, &iosb, 0, 0);
    if (index <= 0) {
      say("*** Error - read truncated\n");
      return 0;
    }
    numchar += index;
  }
#ifdef DEBUG
  printf("Expecting %d bytes - have %d\n", rsize, numchar);
#endif

  current_nl = namelist_create();
  packet = (struct HEADER *)(buffer + 2);
  top_of_record = (char *)packet;
  qdcount = tcp_htons(packet->qdcount);
  ancount = tcp_htons(packet->ancount);
  nscount = tcp_htons(packet->nscount);
  arcount = tcp_htons(packet->arcount);

  index = sizeof(struct HEADER);

#ifdef DEBUG
  printf("Incoming packet (%d bytes)\n", numchar);
  for (loop=0;loop<numchar;loop++)
    if (isalpha(buffer[loop])) printf("%c", buffer[loop]);
    else printf("%d|", buffer[loop]);
  printf("\n");
  printf("Packet-Id: (%d) b1 (%d) b2 (%d)\n", packet->id, packet->b1, packet->b2);
  printf("There were: %d questions %d answers %d nameservers %d additional\n",
         qdcount, ancount, nscount, arcount);
#endif

/* We are finished with the header */
/* Now it is time to process the question */

  for (loop=0; loop<(int)qdcount; loop++)
    (void)question(top_of_record, &index, numchar, current_nl);

/* And most importantly, process the answers to our question */
#ifdef DEBUG
  printf("about to exam an responses\n");
#endif

  for (loop=0; loop<(int)ancount; loop++)
    if (rrecord(top_of_record, some_host, realip, &index, current_nl, 1))
      if (want_type == TYPE_LOOKUP) return 1;
  if (want_type == TYPE_LOOKUP) return 0;

#ifdef DEBUG
  printf("about to exam ns responses\n");
#endif
  for (loop=0; loop<(int)nscount; loop++)
    rrecord(top_of_record, some_host, realip, &index, current_nl, 0);
#ifdef DEBUG
  printf("about to exam additional records\n");
#endif
  for (loop=0; loop<(int)arcount; loop++)
    if (rrecord(top_of_record, some_host, realip, &index, current_nl, 0))
      if (want_type == TYPE_SERVER) good = 1;
  if (want_type == TYPE_SERVER)
  {
    name_entry *walk = current_nl->list;
    while (walk && (walk->type != T_NS)) walk = walk->next;
    if (walk)
      memcpy(realip, &walk->domain_ip, sizeof(walk->domain_ip));
  }
#ifdef DEBUG
  {
    name_entry *walk;
    for (walk = current_nl->list; walk; walk = walk->next)
    {
      printf("map: %s -> %d.%d.%d.%d\n", walk->domain,
             walk->domain_ip.S_un.S_un_b.s_b1,
             walk->domain_ip.S_un.S_un_b.s_b2,
             walk->domain_ip.S_un.S_un_b.s_b3,
             walk->domain_ip.S_un.S_un_b.s_b4);
    }
  }
#endif
  namelist_destroy(current_nl);
  return good;
}

/* ------------------------------------------------------------------------ */
/* Prepare a question and send it out - according to RFC1035 */
/* ------------------------------------------------------------------------ */
static void prep_packet(chan, mem, host, query_type)
  unsigned short chan;
  char *mem;
  char *host;
  unsigned short query_type;
{
  /* Build a single-question query for `host` (class IN, type query_type)
   * in `mem` and send it on `chan`.  The layout is a 2-byte TCP length
   * prefix, then the header, then QNAME, QTYPE and QCLASS.  The id is the
   * low 16 bits of the current time, and recursion is requested.
   * QNAME encoding stops at an empty label, a label over 63 bytes, or when
   * the name would exceed 255 bytes, so the query is always well formed.
   * A single trailing dot ("example.com.") is accepted.
   * `mem` must hold the message (at most header + 261 bytes + 2). */
  enum { DNS_MAX_LABEL = 63, DNS_MAX_NAME = 255 };
  char *orig = mem;                    /* start of the length prefix */
  struct HEADER *pac;
  unsigned char *qname, *out;
  const char *label;
  unsigned short id, qdcount, temp, class;
  int size;

  pac = (struct HEADER *)(mem+2);
  /* Plain pointer arithmetic: the old (int) cast truncated pointers that
     are wider than int. */
  qname = (unsigned char *)pac + sizeof(struct HEADER);

  /* Uses the real time() from <time.h> instead of a hand-written,
     mismatched "extern int time(int *)" declaration. */
  id = (unsigned short)time(NULL);
  pac->id = tcp_htons(id);

#ifndef BIT_CODE
  pac->b1 = 1;
  pac->b2 = 0;
#else
  pac->query = 0;    /* 1 bit  query = 0             */
  pac->opcode = 0;   /* 4 bits normal query = 0      */
  pac->aa = 0;       /* 1 bit  response only         */
  pac->tc = 0;       /* 1 bit  truncated = 1         */
  pac->rd = 1;       /* 1 bit  recursion desired = 1 */

  pac->ra = 0;       /* 1 bit  response only         */
  pac->z = 0;        /* 3 bits must be 0 (reserved)  */
  pac->rcode = 0;    /* 4 bits response code         */
#endif

  qdcount=1;
  pac->qdcount = tcp_htons(qdcount);
  pac->ancount = 0;
  pac->nscount = 0;
  pac->arcount = 0;

  /* QNAME: length-prefixed labels terminated by a zero byte. */
  out = qname;
  for (label = host; *label != '\0'; ) {
    const char *dot = strchr(label, '.');
    size_t n = dot ? (size_t)(dot - label) : strlen(label);

    if (n == 0 || n > DNS_MAX_LABEL ||
        (size_t)(out - qname) + 1 + n + 1 > DNS_MAX_NAME)
      break;
    *out++ = (unsigned char)n;
    memcpy(out, label, n);
    out += n;
    if (dot == NULL) break;
    label = dot + 1;
  }
  *out++ = 0;

  temp = tcp_htons(query_type);
  (void)memcpy(out, &temp, sizeof(temp));
  out += sizeof(temp);

  class = 1;     /* Class = INTERNET */
  temp = tcp_htons(class);
  (void)memcpy(out, &temp, sizeof(temp));
  out += sizeof(temp);

  /* The TCP length prefix counts the message only, not itself. */
  size = (int)(out - (unsigned char *)pac);
  temp = tcp_htons((unsigned short)size);
  (void)memcpy(orig, &temp, sizeof(temp));
  size += 2;
  (void)tcp_send(chan, orig, size);

#ifdef DEBUG
  printf("Question: \n");
  for (mem=orig; size!=0; mem++, size--)
    if (isalpha((unsigned char)*mem)) printf("%c", *mem);
    else printf("%d|", *mem);
  printf("\n");
#endif
}

/* ------------------------------------------------------------------------ */
/* Resolve `host` to an IPv4 address by sending a T_A query to
 * `nameserver` (dotted-decimal) over TCP.
 * Returns nonzero when an A record was found, in which case its 4 bytes
 * are copied to realip.  Returns FALSE untouched if the server cannot be
 * reached.  If the server answers without an A record (or the read
 * fails), realip is set to 0.0.0.0 rather than uninitialised bytes. */
int nslookup(nameserver, host, realip)
  char *nameserver;
  char *host;
  DOMAIN realip;
{
  char t1[MAXBRK];                /* Buffer for the outgoing query */
  short chan;                     /* The chan assigned to the network dev. */
                                  /* that we will be read/write from. */
  char buffer[MAXBRK] = "\0\0";   /* The buffer used for the socket read */
  int num_char;                   /* int, so a -1 error is not seen as 65535 */
  int match = FALSE;              /* Did we find a good match */
  struct iosb_type iosb;
  DOMAIN IP;
  int loop;

  memset(&IP, 0, sizeof(IP));

  if ((chan = start_client(nameserver, NS_PORT)) == 0) return FALSE;
  prep_packet(chan, t1, host, T_A);
  /* NOTE: blocks until the server replies ("hanging here"). */
  num_char = tcp_receive(chan, buffer, sizeof(buffer), &iosb, 0, 0);
  if (num_char > 0)
    match = examoutput(chan, buffer, host, IP, num_char, nameserver,
                       TYPE_LOOKUP);

  for (loop=0;loop<4;loop++) realip[loop] = IP[loop];
  tcp_shutdown(chan);
  return match;
}

/* ------------------------------------------------------------------------ */
/* Fallback name servers (dotted-decimal), tried in turn by find_server. */
static char all_nameserver[][20] = {
  "128.205.1.2",
  "127.0.0.1",
  "192.33.33.51"
};

/* Find the address of a name server for `host`.
 *
 * In MULTINET/UCX builds with `shortcut` set, the address in the
 * MULTINET_NAMESERVERS / UCX$BIND_SERVER000 logical is used when
 * get_logical() succeeds.  Otherwise the function connects to one of
 * all_nameserver[], cycling through them for at most 20 attempts.  It then
 * sends a T_ANY query for `host` and for each parent domain in turn
 * ("a.b.c", "b.c", "c") until a reply yields a server address in realip.
 *
 * Returns TRUE when realip was filled from a reply or from the logical.
 * Returns FALSE otherwise: after a successful connect, realip then holds
 * the address of the fallback server that was contacted; if no server
 * could be reached, realip is left untouched.
 */
int find_server(host, realip, shortcut)
  char *host;
  DOMAIN realip;
  int shortcut;
{
  char t1[MAXBRK], msg[80];
  int count = 0;
  int nservers = (int)(sizeof(all_nameserver) / sizeof(all_nameserver[0]));
  short chan;                     /* The chan assigned to the network dev. */
                                  /* that we will be read/write from. */
  char buffer[MAXBRK] = "\0\0";   /* The buffer used for the socket read */
  int num_char, match;            /* Did we find a good match */
  char temp[80];                  /* Receives the name server logical */
  struct iosb_type iosb;

#ifdef DEBUG
  printf("Paramater addresses: host(%p) realip(%p) shortcut(%d)\n",
      (void *)host, (void *)realip, shortcut);
#endif

#ifdef MULTINET
  if (!shortcut || !get_logical("MULTINET_NAMESERVERS", temp, TRUE)) {
#else
#ifdef UCX
  if (!shortcut || !get_logical("UCX$BIND_SERVER000", temp, TRUE)) {
#else
  if (1) {
#endif
#endif

#ifdef DEBUG_PARAM
  say("1aParamater addresses: host(%p) realip(%p) shortcut(%d)\n",
      (void *)host, (void *)realip, shortcut);
#endif
    while ((chan = start_client(all_nameserver[count % nservers], NS_PORT)) == 0)
      if (++count == 20) {
#ifdef DEBUG_PARAM
  say("1bParamater addresses: host(%p) realip(%p) shortcut(%d)\n",
      (void *)host, (void *)realip, shortcut);
#endif
        return FALSE;
      }
#ifdef DEBUG_PARAM
  say("1Paramater addresses: host(%p) realip(%p) shortcut(%d)\n",
      (void *)host, (void *)realip, shortcut);
#endif
    match = FALSE;
    while ((!match) && (*host)) {
      /* prep_packet only reads the name, so pass host directly; the old
         strcpy into the 80-byte temp[] could overflow on long names. */
      prep_packet(chan, t1, host, T_ANY);
      num_char = tcp_receive(chan, buffer, sizeof(buffer), &iosb, 0, 0);
      if (num_char <= 0)
      {
        if (num_char < 0) {
          get_socket_error(msg);
          (void)say("Error in receive: Chan: %d %s\n", chan, msg);
        }
        /* The connection is unusable; do not parse a stale buffer or keep
           querying. Fall back to the contacted server below. */
        break;
      }
      match = examoutput(chan, buffer, host, realip, num_char,
                         all_nameserver[count % nservers], TYPE_SERVER);
      if (match) {
        tcp_shutdown(chan);
        return TRUE;
      }
      /* Retry with the parent domain: drop the first label. */
      while ((*host) && (*host!='.')) host++;
      if (*host) host++;
    }
    ns_good_host(realip, all_nameserver[count % nservers]);
    tcp_shutdown(chan);
    return FALSE;
  }
  ns_good_host(realip, temp);
  return TRUE;
}
