collect more statistics; add sanity check to avoid bad icmp packets
[unix-history] / usr / src / sys / netinet / ip_input.c
index 0e96cea..37f7c27 100644 (file)
@@ -1,16 +1,18 @@
-/*     ip_input.c      1.57    82/10/30        */
+/*     ip_input.c      1.65    83/02/23        */
 
 #include "../h/param.h"
 #include "../h/systm.h"
 #include "../h/mbuf.h"
 
 #include "../h/param.h"
 #include "../h/systm.h"
 #include "../h/mbuf.h"
+#include "../h/domain.h"
 #include "../h/protosw.h"
 #include "../h/socket.h"
 #include "../h/protosw.h"
 #include "../h/socket.h"
-#include <errno.h>
-#include <time.h>
+#include "../h/errno.h"
+#include "../h/time.h"
 #include "../h/kernel.h"
 
 #include "../net/if.h"
 #include "../net/route.h"
 #include "../h/kernel.h"
 
 #include "../net/if.h"
 #include "../net/route.h"
+
 #include "../netinet/in.h"
 #include "../netinet/in_pcb.h"
 #include "../netinet/in_systm.h"
 #include "../netinet/in.h"
 #include "../netinet/in_pcb.h"
 #include "../netinet/in_systm.h"
@@ -36,11 +38,12 @@ ip_init()
        if (pr == 0)
                panic("ip_init");
        for (i = 0; i < IPPROTO_MAX; i++)
        if (pr == 0)
                panic("ip_init");
        for (i = 0; i < IPPROTO_MAX; i++)
-               ip_protox[i] = pr - protosw;
-       for (pr = protosw; pr <= protoswLAST; pr++)
+               ip_protox[i] = pr - inetsw;
+       for (pr = inetdomain.dom_protosw;
+           pr <= inetdomain.dom_protoswNPROTOSW; pr++)
                if (pr->pr_family == PF_INET &&
                    pr->pr_protocol && pr->pr_protocol != IPPROTO_RAW)
                if (pr->pr_family == PF_INET &&
                    pr->pr_protocol && pr->pr_protocol != IPPROTO_RAW)
-                       ip_protox[pr->pr_protocol] = pr - protosw;
+                       ip_protox[pr->pr_protocol] = pr - inetsw;
        ipq.next = ipq.prev = &ipq;
        ip_id = time.tv_sec & 0xffff;
        ipintrq.ifq_maxlen = ipqmaxlen;
        ipq.next = ipq.prev = &ipq;
        ip_id = time.tv_sec & 0xffff;
        ipintrq.ifq_maxlen = ipqmaxlen;
@@ -76,17 +79,20 @@ next:
        if (m == 0)
                return;
        if ((m->m_off > MMAXOFF || m->m_len < sizeof (struct ip)) &&
        if (m == 0)
                return;
        if ((m->m_off > MMAXOFF || m->m_len < sizeof (struct ip)) &&
-           (m = m_pullup(m, sizeof (struct ip))) == 0)
-               return;
+           (m = m_pullup(m, sizeof (struct ip))) == 0) {
+               ipstat.ips_toosmall++;
+               goto next;
+       }
        ip = mtod(m, struct ip *);
        if ((hlen = ip->ip_hl << 2) > m->m_len) {
        ip = mtod(m, struct ip *);
        if ((hlen = ip->ip_hl << 2) > m->m_len) {
-               if ((m = m_pullup(m, hlen)) == 0)
-                       return;
+               if ((m = m_pullup(m, hlen)) == 0) {
+                       ipstat.ips_badhlen++;
+                       goto next;
+               }
                ip = mtod(m, struct ip *);
        }
        if (ipcksum)
                if (ip->ip_sum = in_cksum(m, hlen)) {
                ip = mtod(m, struct ip *);
        }
        if (ipcksum)
                if (ip->ip_sum = in_cksum(m, hlen)) {
-                       printf("ip_sum %x\n", ip->ip_sum);      /* XXX */
                        ipstat.ips_badsum++;
                        goto bad;
                }
                        ipstat.ips_badsum++;
                        goto bad;
                }
@@ -95,6 +101,10 @@ next:
         * Convert fields to host representation.
         */
        ip->ip_len = ntohs((u_short)ip->ip_len);
         * Convert fields to host representation.
         */
        ip->ip_len = ntohs((u_short)ip->ip_len);
+       if (ip->ip_len < hlen) {
+               ipstat.ips_badlen++;
+               goto bad;
+       }
        ip->ip_id = ntohs(ip->ip_id);
        ip->ip_off = ntohs((u_short)ip->ip_off);
 
        ip->ip_id = ntohs(ip->ip_id);
        ip->ip_off = ntohs((u_short)ip->ip_off);
 
@@ -148,6 +158,20 @@ next:
                    sin->sin_addr.s_addr == ip->ip_dst.s_addr)
                        goto ours;
        }
                    sin->sin_addr.s_addr == ip->ip_dst.s_addr)
                        goto ours;
        }
+/* BEGIN GROT */
+#include "nd.h"
+#if NND > 0
+       /*
+        * Diskless machines don't initially know
+        * their address, so take packets from them
+        * if we're acting as a network disk server.
+        */
+       if (ip->ip_dst.s_addr == INADDR_ANY &&
+           (in_netof(ip->ip_src) == INADDR_ANY &&
+            in_lnaof(ip->ip_src) != INADDR_ANY))
+               goto ours;
+#endif
+/* END GROT */
        ipaddr.sin_addr = ip->ip_dst;
        if (if_ifwithaddr((struct sockaddr *)&ipaddr) == 0) {
                ip_forward(ip);
        ipaddr.sin_addr = ip->ip_dst;
        if (if_ifwithaddr((struct sockaddr *)&ipaddr) == 0) {
                ip_forward(ip);
@@ -192,12 +216,12 @@ found:
                m = dtom(ip);
        } else
                if (fp)
                m = dtom(ip);
        } else
                if (fp)
-                       (void) ip_freef(fp);
+                       ip_freef(fp);
 
        /*
         * Switch out to protocol's input routine.
         */
 
        /*
         * Switch out to protocol's input routine.
         */
-       (*protosw[ip_protox[ip->ip_p]].pr_input)(m);
+       (*inetsw[ip_protox[ip->ip_p]].pr_input)(m);
        goto next;
 bad:
        m_freem(m);
        goto next;
 bad:
        m_freem(m);
@@ -232,7 +256,7 @@ ip_reass(ip, fp)
         * If first fragment to arrive, create a reassembly queue.
         */
        if (fp == 0) {
         * If first fragment to arrive, create a reassembly queue.
         */
        if (fp == 0) {
-               if ((t = m_get(M_WAIT)) == NULL)
+               if ((t = m_get(M_WAIT, MT_FTABLE)) == NULL)
                        goto dropfrag;
                fp = mtod(t, struct ipq *);
                insque(fp, &ipq);
                        goto dropfrag;
                fp = mtod(t, struct ipq *);
                insque(fp, &ipq);
@@ -342,20 +366,18 @@ dropfrag:
  * Free a fragment reassembly header and all
  * associated datagrams.
  */
  * Free a fragment reassembly header and all
  * associated datagrams.
  */
-struct ipq *
 ip_freef(fp)
        struct ipq *fp;
 {
 ip_freef(fp)
        struct ipq *fp;
 {
-       register struct ipasfrag *q;
-       struct mbuf *m;
+       register struct ipasfrag *q, *p;
 
 
-       for (q = fp->ipq_next; q != (struct ipasfrag *)fp; q = q->ipf_next)
+       for (q = fp->ipq_next; q != (struct ipasfrag *)fp; q = p) {
+               p = q->ipf_next;
+               ip_deq(q);
                m_freem(dtom(q));
                m_freem(dtom(q));
-       m = dtom(fp);
-       fp = fp->next;
-       remque(fp->prev);
-       (void) m_free(m);
-       return (fp);
+       }
+       remque(fp);
+       (void) m_free(dtom(fp));
 }
 
 /*
 }
 
 /*
@@ -398,11 +420,12 @@ ip_slowtimo()
                splx(s);
                return;
        }
                splx(s);
                return;
        }
-       while (fp != &ipq)
-               if (--fp->ipq_ttl == 0)
-                       fp = ip_freef(fp);
-               else
-                       fp = fp->next;
+       while (fp != &ipq) {
+               --fp->ipq_ttl;
+               fp = fp->next;
+               if (fp->prev->ipq_ttl == 0)
+                       ip_freef(fp->prev);
+       }
        splx(s);
 }
 
        splx(s);
 }
 
@@ -413,7 +436,7 @@ ip_drain()
 {
 
        while (ipq.next != &ipq)
 {
 
        while (ipq.next != &ipq)
-               (void) ip_freef(ipq.next);
+               ip_freef(ipq.next);
 }
 
 /*
 }
 
 /*
@@ -505,7 +528,7 @@ ip_dooptions(ip)
 
                        case IPOPT_TS_PRESPEC:
                                ipaddr.sin_addr = *sin;
 
                        case IPOPT_TS_PRESPEC:
                                ipaddr.sin_addr = *sin;
-                               if (!if_ifwithaddr((struct sockaddr *)&ipaddr))
+                               if (if_ifwithaddr((struct sockaddr *)&ipaddr) == 0)
                                        continue;
                                if (ipt->ipt_ptr + 8 > ipt->ipt_len)
                                        goto bad;
                                        continue;
                                if (ipt->ipt_ptr + 8 > ipt->ipt_len)
                                        goto bad;
@@ -610,8 +633,8 @@ ip_forward(ip)
                goto sendicmp;
        }
        ip->ip_ttl -= IPTTLDEC;
                goto sendicmp;
        }
        ip->ip_ttl -= IPTTLDEC;
-       mopt = m_get(M_DONTWAIT);
-       if (mopt == 0) {
+       mopt = m_get(M_DONTWAIT, MT_DATA);
+       if (mopt == NULL) {
                m_freem(dtom(ip));
                return;
        }
                m_freem(dtom(ip));
                return;
        }