#include <stdio.h>
#include <iodef.h>
#include <brkdef.h>
#include <descrip.h>
#include <errno.h>
#include <ctype.h>
#include <stdlib.h>
#include <string.h>
#include "global.h"
#include "nstype.h"
#include "defines.h"

/* Port number handed to start_client() when contacting a name server */
#define NS_PORT 53
#define DOMLEN 64    /* Room for one label: 63 octets plus the terminating NUL */
#ifndef MAXBRK
#define MAXBRK 4096  /* Max number of characters we can read from the socket at once */
#endif

/* ------------------------------------------------------------------------ */
/* Change the host ip from string format to numerical */

/*
 * Convert a dotted-quad string ("a.b.c.d") into its four address octets.
 *
 *   int_addr  receives the octets, int_addr[0] being the leftmost field
 *   host      NUL-terminated text to parse
 *
 * Only the fields that sscanf() actually converts are stored; any remaining
 * elements of int_addr are left untouched, so callers that need a defined
 * result for malformed input should pre-initialise the array.
 */
void ns_good_host(int_addr, host)
  unsigned char int_addr[4];
  char *host;
{
  int octet[4];
  int converted, i;

  /*
   * "%d" stores a full int.  Scanning straight into the unsigned char
   * elements would write past each one (and past the end of the array),
   * so parse into ints first and narrow afterwards.
   */
  converted = sscanf(host, "%d.%d.%d.%d", &octet[0], &octet[1], &octet[2],
                     &octet[3]);

  /* converted is EOF (negative) on input failure: the loop then stores nothing */
  for (i = 0; i < converted; i++)
    int_addr[i] = (unsigned char)octet[i];
}

/* ------------------------------------------------------------------------ */
/* Initialize our connection to the name server */

/*
 * Open a stream socket and connect it to a name server.
 *
 *   nameserver  dotted-quad address of the server, as text
 *   port        port number in host byte order (converted with tcp_htons)
 *   sin         socket address structure filled in by this routine
 *
 * Returns the channel from tcp_socket() on success, or 0 when either the
 * socket could not be created or the connect failed.  Failures are reported
 * through say(); after a failed connect the channel is shut down again.
 */
short start_client(nameserver, port, sin)
  char *nameserver;
  unsigned short port;
  struct sockaddr_in_type *sin;   /* The sin for our connection */
{
  short chan;
  unsigned hid = 0;   /* stays 0 if the nameserver text does not parse */
  char msg[80];

  sin->sin_family = AF_INET;
  sin->sin_port = tcp_htons(port);
  sin->sin_zero[0] = 0;
  sin->sin_zero[1] = 0;

  /* ns_good_host() fills the four bytes of hid in the order they appear in
     the string.  NOTE(review): assumes unsigned is at least 4 bytes wide. */
  ns_good_host((unsigned char *)&hid, nameserver);
  sin->sin_address = hid;

  if ((chan = tcp_socket(AF_INET, SOCK_STREAM, PF_INET)) == -1) {
    get_socket_error(msg);
    (void)say("nsl socket: %s (%s)\n", nameserver, msg);
    return 0;
  }
  if (tcp_connect(chan, sin, sizeof(*sin)) == -1) {
    get_socket_error(msg);
    (void)say("nsl connect: %s (%s)\n", nameserver, msg);
    tcp_shutdown(chan);
    return 0;
  }
  return chan;
}

/* ------------------------------------------------------------------------ */

/*
 * One entry in the list of labels seen while decoding a reply.
 * getdomain() adds an entry per label so that later compression pointers
 * can be resolved by offset; set_domain_end() adds marker entries.
 */
struct ddomain {
  unsigned short offset;   /* Offset of the label's length byte in the record, */
                           /* or 65535 for an end-of-domain marker */
  char domain[DOMLEN];     /* The label text, NUL terminated (empty for a marker) */
  struct ddomain *next;    /* Next entry, NULL at the end of the list */
};

/* Head of the label list; emptied by kill_domain_list() */
struct ddomain *top = NULL;
/* Start of the record being decoded; set by examoutput() */
char *top_of_record = NULL;

/* ------------------------------------------------------------------------ */
/* Deallocate all of the memory that we have reserved */

void kill_domain_list(void)
{
  struct ddomain *temp;
  while (top != NULL)
  {
    temp = top;
    top = top->next;
    free(temp);
  }
}

/* ------------------------------------------------------------------------ */
/* Put a record in the linked list to specify the end of a domain */

/*
 * Push an end-of-domain marker (offset 65535, empty label) onto the front
 * of the label list.
 *
 * Returns TRUE on success, FALSE if the marker could not be allocated.
 */
int set_domain_end(void)
{
  struct ddomain *marker = (struct ddomain *)malloc(sizeof(struct ddomain));

  if (marker != NULL) {
    marker->next = top;
    marker->domain[0] = 0;
    marker->offset = 65535;
    top = marker;
    return TRUE;
  }
  return FALSE;
}

/* ------------------------------------------------------------------------ */

/*
 * Decode one (possibly compressed) domain name from a DNS record.
 *
 *   buffer    start of the record; compression offsets are relative to it
 *   my_index  in: index of the name's first length byte
 *             out: index of the byte following the name (success only)
 *   output    receives the dotted name, NUL terminated.  The size of this
 *             buffer is not known here and nothing is bounds-checked, so
 *             the caller must supply room for the longest name it accepts.
 *
 * Every literal label is also recorded in the label list (offset + text)
 * so that a later compression pointer to it can be expanded; on success an
 * end-of-domain marker is added via set_domain_end().
 *
 * Returns TRUE on success, FALSE if a list entry could not be allocated,
 * and -1 if the name is malformed (a pointer to an offset that was never
 * recorded, or a length byte using the reserved 01/10 top bits).  On the
 * failure returns *my_index is left unchanged.
 */
int getdomain(buffer, my_index, output)
  char *buffer;
  int *my_index;
  char *output;
{
  unsigned short wptr;
  /* Raw packet bytes are handled as unsigned: with a signed char the low
     byte of a pointer >= 0x80 would sign-extend and yield a wrong offset */
  unsigned char newlen, wnext;
  struct ddomain *temp, *walk, *back;
  int ind;

  output[0] = 0;
  ind = *my_index;
  newlen = (unsigned char)buffer[ind++];

  while (newlen != 0) {
    if ((newlen & 192) == 192) /* Top 2 bits set: compression pointer */
    {
      wnext = (unsigned char)buffer[ind++];
      /* 14-bit offset (from ID field): low 6 bits of first byte, then second */
      wptr = (unsigned short)(((newlen & 63) << 8) | wnext);
      newlen = 0;                /* NO more domains to follow */

      for (walk=top; (walk) && ((int)walk->offset != wptr); walk = walk->next)
        ;

      if (!walk) return -1;

      /* Append the recorded labels up to the next end-of-domain marker */
      while (walk && ((int)walk->offset != 65535))
      {
        (void)strcat(output, walk->domain);
        if ((walk->next) && ((int)walk->next->offset != 65535))
          (void)strcat(output, ".");
        walk = walk->next;
      }
    }
    else
    {
      /* A literal label has both top bits clear, i.e. length <= 63.  The
         01/10 patterns are reserved and would overflow temp->domain. */
      if ((newlen & 192) != 0) return -1;

      (void)strncat(output, &buffer[ind], newlen);
      if ((temp = (struct ddomain *)malloc(sizeof(struct ddomain))) == NULL)
        return FALSE;
      temp->offset = (unsigned short)(ind - 1); /* Offset of the length byte */
      (void)strncpy(temp->domain, &buffer[ind], newlen);
      temp->domain[newlen] = 0;

/* Insert the entry just in front of the first end-of-domain marker */

      walk = back = top;
      while ((walk) && ((int)walk->offset != 65535)) {
        back = walk;
        walk = walk->next;
      }
      temp->next = walk;
      if (back == walk)
        top = temp;
      else
        back->next = temp;

/* done inserting it into the list */
      ind += newlen;
      newlen = (unsigned char)buffer[ind++];
      if (newlen != 0) (void)strcat(output, ".");
    }
  }
  (void)set_domain_end();
  *my_index = ind;
  return TRUE;
}

/* ------------------------------------------------------------------------ */

/*
 * Step over one entry of the question section.
 *
 *   buffer   start of the record
 *   index    in: index of the question; out: index just past it
 *   numchar  not used by this routine
 *
 * The name is decoded with getdomain() (which also records its labels for
 * later compression pointers), then the 2-byte type and class fields are
 * skipped.  The decoded values are only reported when DEBUG is defined.
 */
void question(buffer, index, numchar)
  char *buffer;
  int *index, numchar;
{
  unsigned short class, type;
  /* getdomain() does no bounds checking, so leave room for a maximum-length
     (255 octet) name, the same size rrecord() uses */
  char output[256];

/* What did we ask the question about */
  (void)getdomain(buffer, index, output);
  (void)memcpy(&type, &buffer[*index], sizeof(type));
  *index += sizeof(type);
  (void)memcpy(&class, &buffer[*index], sizeof(class));
  *index += sizeof(class);
#ifdef DEBUG
  (void)say("Class: %d Type: %d [%s]\n", tcp_htons(class), tcp_htons(type), output);
#endif
}

/* ------------------------------------------------------------------------ */

/*
 * Decode one resource record.
 *
 *   buffer    start of the record
 *   host      receives the target name of an NS or CNAME record
 *             (copied with strcpy; the caller supplies the space)
 *   realip    receives the address bytes of an A record
 *   my_index  in: index of the record; out: index of the next record,
 *             computed from the record's RDLENGTH
 *
 * Returns TRUE only when an A record was decoded into realip; FALSE for
 * every other record type.
 */
int rrecord(buffer, host, realip, my_index)
  char *buffer, *host;
  DOMAIN realip;
  int *my_index;
{
  unsigned short type, class, rdlength;
  unsigned ttl;            /* class and ttl are only used for their sizes */
  char output[256];
  int loop, ind, end_ind, good = FALSE;

  ind = *my_index;
  (void)getdomain(buffer, &ind, output);  /* The domain in question */

  (void)memcpy(&type, &buffer[ind], sizeof(type));
  ind += sizeof(type);
  type = tcp_htons(type);

  /* Skip the class and TTL fields; their values are not needed */
  ind = ind + sizeof(class) + sizeof(ttl);

  (void)memcpy(&rdlength, &buffer[ind], sizeof(rdlength));
  ind += sizeof(rdlength);
  rdlength = tcp_htons(rdlength);

  end_ind = ind + rdlength;

/* We have to do different things depending upon the type of response */

  switch ((int)type)
  {
    case T_A:
      for (loop=0; loop<MAXIP; loop++) realip[loop] = 0;
      for (loop=0; loop<4; loop++) {
         realip[loop] = buffer[ind++];
      }
      good = TRUE;
      break;
    case T_NS:
    case T_CNAME:
      (void)getdomain(buffer, &ind, output);
#ifdef DEBUG
      (void)say("Canocial name: [%s]\n", output);
#endif
      (void)strcpy(host, output);
      break;
    default:
      /* Other record types are skipped via end_ind.  The data is not copied
         anywhere: RDLENGTH comes from the packet and may be far larger than
         any local buffer. */
      break;
  }
  *my_index = end_ind;
  return good;
}

/* ------------------------------------------------------------------------ */

/*
 * Walk a reply from the name server looking for an address record.
 *
 *   buffer   the data as read from the socket; the DNS header starts two
 *            bytes in, after the length prefix
 *   host     not used by this routine
 *   realip   receives the address from the first A record found
 *   numchar  result of the receive call; a negative value (the error
 *            indication callers test for) means there is nothing to parse
 *   server   not used by this routine
 *
 * The questions are stepped over, then the answer, name-server and
 * additional sections are scanned in that order.  Returns TRUE as soon as
 * rrecord() reports an A record, otherwise 0.
 */
unsigned short examoutput(buffer, host, realip, numchar, server)
  char *buffer;                  /* The data from the nameserver */
  char *host;                    /* Host name that we are dealing with */
  DOMAIN realip;                 /* Hopefully the IP of that host name */
  int numchar;
  char *server;
{
  int loop;                      /* Loop for number of answers/questions */
  unsigned short qdcount;        /* Returned number of questions */
  unsigned short ancount;        /* Returned number of answers */
  unsigned short nscount;        /* Returned number of name server responses */
  unsigned short arcount;        /* Returned number of additional records */
  struct HEADER *packet;
  int index;
  /* rrecord() strcpy()s names of up to 255 characters into this */
  char some_host[256];

  /* A failed receive leaves nothing meaningful in the buffer */
  if (numchar < 0) return 0;

  packet = (struct HEADER *)(buffer + 2);
  top_of_record = (char *)packet;
  qdcount = tcp_htons(packet->qdcount);
  ancount = tcp_htons(packet->ancount);
  nscount = tcp_htons(packet->nscount);
  arcount = tcp_htons(packet->arcount);

  /* All offsets below are relative to the start of the header */
  index = sizeof(struct HEADER);

#ifdef DEBUG  /**/
  (void)say("There were: %d questions %d answers %d nameservers %d additional\n",
         qdcount, ancount, nscount, arcount);
#endif        /**/

/* We are finished with the header */
/* Now it is time to process the question */

  for (loop=0; loop<(int)qdcount; loop++)
    (void)question(top_of_record, &index, numchar);

/* And most importantly, process the answers to our question */

  for (loop=0; loop<(int)ancount; loop++) {
    if (rrecord(top_of_record, some_host, realip, &index)) return TRUE;
  }

  for (loop=0; loop<(int)nscount; loop++)
    if (rrecord(top_of_record, some_host, realip, &index)) return TRUE;

  for (loop=0; loop<(int)arcount; loop++)
    if (rrecord(top_of_record, some_host, realip, &index)) return TRUE;

  return 0;
}

/* ------------------------------------------------------------------------ */
/* Prepare a question and send it out - according to RFC1035 */
/* ------------------------------------------------------------------------ */

/*
 * Build a single-question query and send it on the given channel.
 *
 *   chan        channel to send on
 *   mem         scratch buffer the packet is assembled in; it must hold the
 *               2-byte length prefix, the header, the encoded name and the
 *               4 bytes of type and class (no size is passed, so nothing is
 *               checked here)
 *   host        dotted host name to ask about
 *   query_type  record type to request, in host byte order
 *
 * Layout written to mem: 2-byte length, struct HEADER (id 4, recursion
 * desired, one question), the name as length-prefixed labels ending in a
 * zero byte, the query type, and class 1.  The result of tcp_send() is not
 * examined.
 */
void prep_packet(chan, mem, host, query_type)
  unsigned short chan;
  char *mem;
  char *host;
  unsigned short query_type;
{
  unsigned short id = 4, qdcount = 1, temp, class;
  char *orig, *len;
  int cut, size;
  struct HEADER *pac;

  orig = mem;
  size = sizeof(struct HEADER);

  pac = (struct HEADER *)(mem+2);
  /* Plain pointer arithmetic: going through int would truncate the address
     where pointers are wider than int */
  mem = (char *)pac + sizeof(struct HEADER);

  pac->id = tcp_htons(id);

  pac->query = 0;    /* 1 bit  query = 0             */
  pac->opcode = 0;   /* 4 bits normal query = 0      */
  pac->aa = 0;       /* 1 bit  response only         */
  pac->tc = 0;       /* 1 bit  truncated = 1         */
  pac->rd = 1;       /* 1 bit  recursion desired = 1 */

  pac->ra = 0;       /* 1 bit  response only         */
  pac->z = 0;        /* 3 bits must be 0 (reserved)  */
  pac->rcode = 0;    /* 4 bits response code         */

  pac->qdcount = tcp_htons(qdcount);
  pac->ancount = 0;
  pac->nscount = 0;
  pac->arcount = 0;

  /*
   * Encode the name in place: copy it one byte in, leaving room for the
   * first length byte, then turn each '.' into the length of the label
   * that follows it.  len always points at the pending length byte.
   */
  len = mem++;
  size++;
  (void)strcpy(mem, host);
  for (cut = 0; *mem; mem++, size++)
    if (*mem=='.') {
      *len = cut;
      len = mem;
      cut = 0;
    }
    else cut++;
  *len = cut;
  *mem = 0;          /* Zero-length label terminates the name */
  mem++; size++;

  temp = tcp_htons(query_type);
  (void)memcpy(mem, &temp, sizeof(query_type));
  size += sizeof(query_type);
  mem += sizeof(query_type);

  class = 1;     /* Class = INTERNET */
  temp = tcp_htons(class);
  (void)memcpy(mem, &temp, sizeof(temp));
  size += sizeof(temp);
  mem += sizeof(temp);

  temp = tcp_htons(size);   /* We are about to push the size on */
  (void)memcpy(orig, &temp, sizeof(temp));
  size += 2;                /* The prefix itself is sent too */
  (void)tcp_send(chan, orig, size);

#ifdef DEBUG
  for (mem=orig; size!=0; mem++, size--)
    if (isalpha((unsigned char)*mem)) (void)say("%c", *mem);
    else (void)say("%d|", *mem);
  (void)say("\n");
#endif
}

/* ------------------------------------------------------------------------ */

/*
 * Ask one name server for the address of a host.
 *
 *   nameserver  dotted-quad address of the server to query
 *   host        host name to look up
 *   realip      receives the first four address bytes; all zero when no
 *               address was found
 *
 * Returns the result of examoutput() (TRUE when an A record was found),
 * or FALSE when the server could not be reached, in which case realip is
 * not written.
 */
int nslookup(nameserver, host, realip)
  char *nameserver;
  char *host;
  DOMAIN realip;
{
  char t1[MAXBRK];
  short chan = 0;                 /* The chan assigned to the network dev. */
                                  /* that we will be read/write from. */
  char buffer[MAXBRK] = "\0\0";   /* The buffer used for the socket read */
  struct sockaddr_in_type sin;    /* The sin for our connection */
  int num_char;                   /* Kept as int so a -1 error is not lost */
  unsigned short match;           /* Did we find a good match */
  struct iosb_type iosb;
  DOMAIN IP;
  int loop;

  match = FALSE;
  /* Give the result a defined value for the paths that never fill it in */
  for (loop=0;loop<4;loop++) IP[loop] = 0;

  if (set_domain_end() != FALSE)
  {
    if ((chan = start_client(nameserver, NS_PORT, &sin)) == 0) {
      kill_domain_list();         /* Don't leak the marker just allocated */
      return FALSE;
    }
    prep_packet(chan, t1, host, T_A);
    num_char = tcp_receive(chan, buffer, sizeof(buffer), &iosb, 0, 0);
    if (num_char != -1)           /* -1 is the receive error indication */
      match = examoutput(buffer, host , IP, num_char, nameserver);
    kill_domain_list();
  }
  for (loop=0;loop<4;loop++) realip[loop] = IP[loop];
  /* NOTE(review): exiting the program on an empty nameserver string is the
     original behaviour; the reason is not evident from this file */
  if (!*nameserver) exit(1);
  if (chan != 0) tcp_shutdown(chan);   /* Only if a channel was opened */
  return match;
}

/* ------------------------------------------------------------------------ */

/*
 * Built-in name server addresses, as dotted-quad text.  find_server()
 * cycles through them (index count % 3) until one accepts a connection.
 */
char all_nameserver[3][20] = {
"128.205.1.2",
"127.0.0.1",
"192.33.33.51",
};

/*
 * Find a name server for a host.
 *
 *   host    host name; the search starts with the full name and drops the
 *           leftmost label after every unsuccessful query
 *   realip  receives the server address
 *
 * When built for MULTINET or UCX and get_logical() returns non-zero for the
 * respective logical name, its value is parsed as a dotted quad into realip
 * and TRUE is returned (presumably the site's configured server -- confirm
 * against get_logical).
 *
 * Otherwise one of the built-in servers is contacted (up to 20 connection
 * attempts, cycling through all_nameserver) and asked for NS records.
 * Returns TRUE as soon as examoutput() reports an address.  Returns FALSE
 * if no connection could be made (realip untouched), or if the search ran
 * out of labels or memory, in which case realip is set to the address of
 * the built-in server that was used.
 */
int find_server(host, realip)
  char *host;
  DOMAIN realip;
{
  char t1[MAXBRK], msg[80];
  int count = 0;
  short chan;                     /* The chan assigned to the network dev. */
                                  /* that we will be read/write from. */
  char buffer[MAXBRK] = "\0\0";   /* The buffer used for the socket read */
  struct sockaddr_in_type sin;    /* The sin for our connection */
  int num_char, match;  /* Did we find a good match */
  char temp[80];                  /* Value of the name server logical */
  char query[256];                /* Name currently being asked about */
  struct iosb_type iosb;
  int quad[4], nconv, loop;       /* For parsing dotted-quad text */

#ifdef MULTINET
  if (!get_logical("MULTINET_NAMESERVERS", temp)) {
#else
#ifdef UCX
  if (!get_logical("UCX$BIND_SERVER000", temp)) {
#else
  if (1) {
#endif
#endif

  while ((chan = start_client(all_nameserver[count%3], NS_PORT, &sin)) == 0)
    if (++count == 20) return FALSE;
  match = FALSE;
  while ((!match) && (*host)) {
    /* Without this exit an allocation failure would loop forever, since
       nothing below would run to advance host */
    if (set_domain_end() == FALSE)
      break;

    /* Bounded copy: host comes from the caller and may be any length */
    (void)strncpy(query, host, sizeof(query) - 1);
    query[sizeof(query) - 1] = 0;
    prep_packet(chan, t1, query, T_NS);
    if ((num_char = tcp_receive(chan, buffer, sizeof(buffer), &iosb, 0, 0)) == -1)
    {
      get_socket_error(msg);
      (void)say("Error in receive: Chan: %d %s\n", chan, msg);
    }
    match = examoutput(buffer, query , realip, num_char,
                       all_nameserver[count%3]);
    kill_domain_list();
    if (match) {
      tcp_shutdown(chan);
      return TRUE;
    }
    /* Drop the leftmost label and try the parent domain */
    while ((*host) && (*host!='.')) host++;
    if (*host) host++;
  }
  /* "%d" needs int targets; parse into ints and store only the fields that
     were converted, whatever the element type of DOMAIN is */
  nconv = sscanf(all_nameserver[count%3], "%d.%d.%d.%d", &quad[0], &quad[1],
                 &quad[2], &quad[3]);
  for (loop = 0; loop < nconv; loop++) realip[loop] = quad[loop];
  tcp_shutdown(chan);
  return FALSE;
  } else {
     nconv = sscanf(temp, "%d.%d.%d.%d", &quad[0], &quad[1],
                    &quad[2], &quad[3]);
     for (loop = 0; loop < nconv; loop++) realip[loop] = quad[loop];
     return TRUE;
  }
}
