check for loops in compressed names
[unix-history] / usr / src / lib / libc / net / res_comp.c
index 256b5fe..60bf4a2 100644 (file)
@@ -4,31 +4,30 @@
  * specifies the terms and conditions for redistribution.
  */
 
  * specifies the terms and conditions for redistribution.
  */
 
-#ifndef lint
-static char sccsid[] = "@(#)res_comp.c 6.1 (Berkeley) %G%";
-#endif not lint
+#if defined(LIBC_SCCS) && !defined(lint)
+static char sccsid[] = "@(#)res_comp.c 6.9 (Berkeley) %G%";
+#endif LIBC_SCCS and not lint
 
 #include <sys/types.h>
 #include <stdio.h>
 
 #include <sys/types.h>
 #include <stdio.h>
-#include <ctype.h>
 #include <arpa/nameser.h>
 
 
 /*
  * Expand compressed domain name 'comp_dn' to full domain name.
 #include <arpa/nameser.h>
 
 
 /*
  * Expand compressed domain name 'comp_dn' to full domain name.
- * Expanded names are converted to lower case.
  * 'msg' is a pointer to the beginning of the message,
  * 'msg' is a pointer to the beginning of the message,
+ * 'eomorig' points to the first location after the message,
  * 'exp_dn' is a pointer to a buffer of size 'length' for the result.
  * Return size of compressed name or -1 if there was an error.
  */
  * 'exp_dn' is a pointer to a buffer of size 'length' for the result.
  * Return size of compressed name or -1 if there was an error.
  */
-dn_expand(msg, comp_dn, exp_dn, length)
-       char *msg, *comp_dn, *exp_dn;
+dn_expand(msg, eomorig, comp_dn, exp_dn, length)
+       char *msg, *eomorig, *comp_dn, *exp_dn;
        int length;
 {
        register char *cp, *dn;
        register int n, c;
        char *eom;
        int length;
 {
        register char *cp, *dn;
        register int n, c;
        char *eom;
-       int len = -1;
+       int len = -1, checked = 0;
 
        dn = exp_dn;
        cp = comp_dn;
 
        dn = exp_dn;
        cp = comp_dn;
@@ -49,23 +48,33 @@ dn_expand(msg, comp_dn, exp_dn, length)
                        }
                        if (dn+n >= eom)
                                return (-1);
                        }
                        if (dn+n >= eom)
                                return (-1);
-                       while (--n >= 0)
-                               if (isupper(c = *cp++))
-                                       *dn++ = tolower(c);
-                               else {
-                                       if (c == '.') {
-                                               if (dn+n+1 >= eom)
-                                                       return (-1);
-                                               *dn++ = '\\';
-                                       }
-                                       *dn++ = c;
+                       checked += n + 1;
+                       while (--n >= 0) {
+                               if ((c = *cp++) == '.') {
+                                       if (dn+n+1 >= eom)
+                                               return (-1);
+                                       *dn++ = '\\';
                                }
                                }
+                               *dn++ = c;
+                               if (cp >= eomorig)      /* out of range */
+                                       return(-1);
+                       }
                        break;
 
                case INDIR_MASK:
                        if (len < 0)
                                len = cp - comp_dn + 1;
                        cp = msg + (((n & 0x3f) << 8) | (*cp & 0xff));
                        break;
 
                case INDIR_MASK:
                        if (len < 0)
                                len = cp - comp_dn + 1;
                        cp = msg + (((n & 0x3f) << 8) | (*cp & 0xff));
+                       if (cp < msg || cp >= eomorig)  /* out of range */
+                               return(-1);
+                       checked += 2;
+                       /*
+                        * Check for loops in the compressed name;
+                        * if we've looked at the whole message,
+                        * there must be a loop.
+                        */
+                       if (checked >= eomorig - msg)
+                               return (-1);
                        break;
 
                default:
                        break;
 
                default:
@@ -102,7 +111,7 @@ dn_comp(exp_dn, comp_dn, length, dnptrs, lastdnptr)
 
        dn = exp_dn;
        cp = comp_dn;
 
        dn = exp_dn;
        cp = comp_dn;
-       eob = comp_dn + length;
+       eob = cp + length;
        if (dnptrs != NULL) {
                if ((msg = *dnptrs++) != NULL) {
                        for (cpp = dnptrs; *cpp != NULL; cpp++)
        if (dnptrs != NULL) {
                if ((msg = *dnptrs++) != NULL) {
                        for (cpp = dnptrs; *cpp != NULL; cpp++)
@@ -118,7 +127,7 @@ dn_comp(exp_dn, comp_dn, length, dnptrs, lastdnptr)
                                if (cp+1 >= eob)
                                        return (-1);
                                *cp++ = (l >> 8) | INDIR_MASK;
                                if (cp+1 >= eob)
                                        return (-1);
                                *cp++ = (l >> 8) | INDIR_MASK;
-                               *cp++ = l;
+                               *cp++ = l % 256;
                                return (cp - comp_dn);
                        }
                        /* not found, save it */
                                return (cp - comp_dn);
                        }
                        /* not found, save it */
@@ -196,7 +205,7 @@ dn_find(exp_dn, msg, dnptrs, lastdnptr)
        register int n;
        char *sp;
 
        register int n;
        char *sp;
 
-       for (cpp = dnptrs; cpp < lastdnptr; cpp++) {
+       for (cpp = dnptrs + 1; cpp < lastdnptr; cpp++) {
                dn = exp_dn;
                sp = cp = *cpp;
                while (n = *cp++) {
                dn = exp_dn;
                sp = cp = *cpp;
                while (n = *cp++) {
@@ -238,18 +247,25 @@ dn_find(exp_dn, msg, dnptrs, lastdnptr)
  */
 
 u_short
  */
 
 u_short
-getshort(msgp)
+_getshort(msgp)
        char *msgp;
 {
        register u_char *p = (u_char *) msgp;
        char *msgp;
 {
        register u_char *p = (u_char *) msgp;
+#ifdef vax
+       /*
+        * vax compiler doesn't put shorts in registers
+        */
+       register u_long u;
+#else
        register u_short u;
        register u_short u;
+#endif
 
        u = *p++ << 8;
 
        u = *p++ << 8;
-       return (u | *p);
+       return ((u_short)(u | *p));
 }
 
 u_long
 }
 
 u_long
-getlong(msgp)
+_getlong(msgp)
        char *msgp;
 {
        register u_char *p = (u_char *) msgp;
        char *msgp;
 {
        register u_char *p = (u_char *) msgp;