collect more statistics; add sanity check to avoid bad icmp packets
[unix-history] / usr / src / sys / netinet / ip_input.c
index cf9864f..37f7c27 100644 (file)
@@ -1,20 +1,25 @@
-/*     ip_input.c      1.51    82/10/09        */
+/*     ip_input.c      1.65    83/02/23        */
 
 #include "../h/param.h"
 #include "../h/systm.h"
 #include "../h/mbuf.h"
 
 #include "../h/param.h"
 #include "../h/systm.h"
 #include "../h/mbuf.h"
+#include "../h/domain.h"
 #include "../h/protosw.h"
 #include "../h/socket.h"
 #include "../h/protosw.h"
 #include "../h/socket.h"
+#include "../h/errno.h"
+#include "../h/time.h"
+#include "../h/kernel.h"
+
+#include "../net/if.h"
+#include "../net/route.h"
+
 #include "../netinet/in.h"
 #include "../netinet/in.h"
+#include "../netinet/in_pcb.h"
 #include "../netinet/in_systm.h"
 #include "../netinet/in_systm.h"
-#include "../net/if.h"
-#include "../netinet/ip.h"                     /* belongs before in.h */
+#include "../netinet/ip.h"
 #include "../netinet/ip_var.h"
 #include "../netinet/ip_icmp.h"
 #include "../netinet/tcp.h"
 #include "../netinet/ip_var.h"
 #include "../netinet/ip_icmp.h"
 #include "../netinet/tcp.h"
-#include <time.h>
-#include "../h/kernel.h"
-#include <errno.h>
 
 u_char ip_protox[IPPROTO_MAX];
 int    ipqmaxlen = IFQ_MAXLEN;
 
 u_char ip_protox[IPPROTO_MAX];
 int    ipqmaxlen = IFQ_MAXLEN;
@@ -33,11 +38,12 @@ ip_init()
        if (pr == 0)
                panic("ip_init");
        for (i = 0; i < IPPROTO_MAX; i++)
        if (pr == 0)
                panic("ip_init");
        for (i = 0; i < IPPROTO_MAX; i++)
-               ip_protox[i] = pr - protosw;
-       for (pr = protosw; pr <= protoswLAST; pr++)
+               ip_protox[i] = pr - inetsw;
+       for (pr = inetdomain.dom_protosw;
+           pr <= inetdomain.dom_protoswNPROTOSW; pr++)
                if (pr->pr_family == PF_INET &&
                    pr->pr_protocol && pr->pr_protocol != IPPROTO_RAW)
                if (pr->pr_family == PF_INET &&
                    pr->pr_protocol && pr->pr_protocol != IPPROTO_RAW)
-                       ip_protox[pr->pr_protocol] = pr - protosw;
+                       ip_protox[pr->pr_protocol] = pr - inetsw;
        ipq.next = ipq.prev = &ipq;
        ip_id = time.tv_sec & 0xffff;
        ipintrq.ifq_maxlen = ipqmaxlen;
        ipq.next = ipq.prev = &ipq;
        ip_id = time.tv_sec & 0xffff;
        ipintrq.ifq_maxlen = ipqmaxlen;
@@ -57,7 +63,7 @@ ipintr()
 {
        register struct ip *ip;
        register struct mbuf *m;
 {
        register struct ip *ip;
        register struct mbuf *m;
-       struct mbuf *m0, *mopt;
+       struct mbuf *m0;
        register int i;
        register struct ipq *fp;
        int hlen, s;
        register int i;
        register struct ipq *fp;
        int hlen, s;
@@ -73,29 +79,34 @@ next:
        if (m == 0)
                return;
        if ((m->m_off > MMAXOFF || m->m_len < sizeof (struct ip)) &&
        if (m == 0)
                return;
        if ((m->m_off > MMAXOFF || m->m_len < sizeof (struct ip)) &&
-           (m = m_pullup(m, sizeof (struct ip))) == 0)
-               return;
+           (m = m_pullup(m, sizeof (struct ip))) == 0) {
+               ipstat.ips_toosmall++;
+               goto next;
+       }
        ip = mtod(m, struct ip *);
        if ((hlen = ip->ip_hl << 2) > m->m_len) {
        ip = mtod(m, struct ip *);
        if ((hlen = ip->ip_hl << 2) > m->m_len) {
-               if ((m = m_pullup(m, hlen)) == 0)
-                       return;
+               if ((m = m_pullup(m, hlen)) == 0) {
+                       ipstat.ips_badhlen++;
+                       goto next;
+               }
                ip = mtod(m, struct ip *);
        }
        if (ipcksum)
                if (ip->ip_sum = in_cksum(m, hlen)) {
                ip = mtod(m, struct ip *);
        }
        if (ipcksum)
                if (ip->ip_sum = in_cksum(m, hlen)) {
-                       printf("ip_sum %x\n", ip->ip_sum);      /* XXX */
                        ipstat.ips_badsum++;
                        goto bad;
                }
 
                        ipstat.ips_badsum++;
                        goto bad;
                }
 
-#if vax
        /*
         * Convert fields to host representation.
         */
        ip->ip_len = ntohs((u_short)ip->ip_len);
        /*
         * Convert fields to host representation.
         */
        ip->ip_len = ntohs((u_short)ip->ip_len);
+       if (ip->ip_len < hlen) {
+               ipstat.ips_badlen++;
+               goto bad;
+       }
        ip->ip_id = ntohs(ip->ip_id);
        ip->ip_off = ntohs((u_short)ip->ip_off);
        ip->ip_id = ntohs(ip->ip_id);
        ip->ip_off = ntohs((u_short)ip->ip_off);
-#endif
 
        /*
         * Check that the amount of data in the buffers
 
        /*
         * Check that the amount of data in the buffers
@@ -147,6 +158,20 @@ next:
                    sin->sin_addr.s_addr == ip->ip_dst.s_addr)
                        goto ours;
        }
                    sin->sin_addr.s_addr == ip->ip_dst.s_addr)
                        goto ours;
        }
+/* BEGIN GROT */
+#include "nd.h"
+#if NND > 0
+       /*
+        * Diskless machines don't initially know
+        * their address, so take packets from them
+        * if we're acting as a network disk server.
+        */
+       if (ip->ip_dst.s_addr == INADDR_ANY &&
+           (in_netof(ip->ip_src) == INADDR_ANY &&
+            in_lnaof(ip->ip_src) != INADDR_ANY))
+               goto ours;
+#endif
+/* END GROT */
        ipaddr.sin_addr = ip->ip_dst;
        if (if_ifwithaddr((struct sockaddr *)&ipaddr) == 0) {
                ip_forward(ip);
        ipaddr.sin_addr = ip->ip_dst;
        if (if_ifwithaddr((struct sockaddr *)&ipaddr) == 0) {
                ip_forward(ip);
@@ -191,12 +216,12 @@ found:
                m = dtom(ip);
        } else
                if (fp)
                m = dtom(ip);
        } else
                if (fp)
-                       (void) ip_freef(fp);
+                       ip_freef(fp);
 
        /*
         * Switch out to protocol's input routine.
         */
 
        /*
         * Switch out to protocol's input routine.
         */
-       (*protosw[ip_protox[ip->ip_p]].pr_input)(m);
+       (*inetsw[ip_protox[ip->ip_p]].pr_input)(m);
        goto next;
 bad:
        m_freem(m);
        goto next;
 bad:
        m_freem(m);
@@ -231,7 +256,7 @@ ip_reass(ip, fp)
         * If first fragment to arrive, create a reassembly queue.
         */
        if (fp == 0) {
         * If first fragment to arrive, create a reassembly queue.
         */
        if (fp == 0) {
-               if ((t = m_get(M_WAIT)) == NULL)
+               if ((t = m_get(M_WAIT, MT_FTABLE)) == NULL)
                        goto dropfrag;
                fp = mtod(t, struct ipq *);
                insque(fp, &ipq);
                        goto dropfrag;
                fp = mtod(t, struct ipq *);
                insque(fp, &ipq);
@@ -341,20 +366,18 @@ dropfrag:
  * Free a fragment reassembly header and all
  * associated datagrams.
  */
  * Free a fragment reassembly header and all
  * associated datagrams.
  */
-struct ipq *
 ip_freef(fp)
        struct ipq *fp;
 {
 ip_freef(fp)
        struct ipq *fp;
 {
-       register struct ipasfrag *q;
-       struct mbuf *m;
+       register struct ipasfrag *q, *p;
 
 
-       for (q = fp->ipq_next; q != (struct ipasfrag *)fp; q = q->ipf_next)
+       for (q = fp->ipq_next; q != (struct ipasfrag *)fp; q = p) {
+               p = q->ipf_next;
+               ip_deq(q);
                m_freem(dtom(q));
                m_freem(dtom(q));
-       m = dtom(fp);
-       fp = fp->next;
-       remque(fp->prev);
-       (void) m_free(m);
-       return (fp);
+       }
+       remque(fp);
+       (void) m_free(dtom(fp));
 }
 
 /*
 }
 
 /*
@@ -397,11 +420,12 @@ ip_slowtimo()
                splx(s);
                return;
        }
                splx(s);
                return;
        }
-       while (fp != &ipq)
-               if (--fp->ipq_ttl == 0)
-                       fp = ip_freef(fp);
-               else
-                       fp = fp->next;
+       while (fp != &ipq) {
+               --fp->ipq_ttl;
+               fp = fp->next;
+               if (fp->prev->ipq_ttl == 0)
+                       ip_freef(fp->prev);
+       }
        splx(s);
 }
 
        splx(s);
 }
 
@@ -412,7 +436,7 @@ ip_drain()
 {
 
        while (ipq.next != &ipq)
 {
 
        while (ipq.next != &ipq)
-               (void) ip_freef(ipq.next);
+               ip_freef(ipq.next);
 }
 
 /*
 }
 
 /*
@@ -504,7 +528,7 @@ ip_dooptions(ip)
 
                        case IPOPT_TS_PRESPEC:
                                ipaddr.sin_addr = *sin;
 
                        case IPOPT_TS_PRESPEC:
                                ipaddr.sin_addr = *sin;
-                               if (!if_ifwithaddr((struct sockaddr *)&ipaddr))
+                               if (if_ifwithaddr((struct sockaddr *)&ipaddr) == 0)
                                        continue;
                                if (ipt->ipt_ptr + 8 > ipt->ipt_len)
                                        goto bad;
                                        continue;
                                if (ipt->ipt_ptr + 8 > ipt->ipt_len)
                                        goto bad;
@@ -563,7 +587,7 @@ ip_ctlinput(cmd, arg)
        int cmd;
        caddr_t arg;
 {
        int cmd;
        caddr_t arg;
 {
-       struct in_addr *sin;
+       struct in_addr *in;
        int tcp_abort(), udp_abort();
        extern struct inpcb tcb, udb;
 
        int tcp_abort(), udp_abort();
        extern struct inpcb tcb, udb;
 
@@ -572,13 +596,14 @@ ip_ctlinput(cmd, arg)
        if (inetctlerrmap[cmd] == 0)
                return;         /* XXX */
        if (cmd == PRC_IFDOWN)
        if (inetctlerrmap[cmd] == 0)
                return;         /* XXX */
        if (cmd == PRC_IFDOWN)
-               sin = &((struct sockaddr_in *)arg)->sin_addr;
+               in = &((struct sockaddr_in *)arg)->sin_addr;
        else if (cmd == PRC_HOSTDEAD || cmd == PRC_HOSTUNREACH)
        else if (cmd == PRC_HOSTDEAD || cmd == PRC_HOSTUNREACH)
-               sin = (struct in_addr *)arg;
+               in = (struct in_addr *)arg;
        else
        else
-               sin = &((struct icmp *)arg)->icmp_ip.ip_dst;
-       in_pcbnotify(&tcb, sin, inetctlerrmap[cmd], tcp_abort);
-       in_pcbnotify(&udb, sin, inetctlerrmap[cmd], udp_abort);
+               in = &((struct icmp *)arg)->icmp_ip.ip_dst;
+/* THIS IS VERY QUESTIONABLE, SHOULD HIT ALL PROTOCOLS */
+       in_pcbnotify(&tcb, in, (int)inetctlerrmap[cmd], tcp_abort);
+       in_pcbnotify(&udb, in, (int)inetctlerrmap[cmd], udp_abort);
 }
 
 int    ipprintfs = 0;
 }
 
 int    ipprintfs = 0;
@@ -608,8 +633,8 @@ ip_forward(ip)
                goto sendicmp;
        }
        ip->ip_ttl -= IPTTLDEC;
                goto sendicmp;
        }
        ip->ip_ttl -= IPTTLDEC;
-       mopt = m_get(M_DONTWAIT);
-       if (mopt == 0) {
+       mopt = m_get(M_DONTWAIT, MT_DATA);
+       if (mopt == NULL) {
                m_freem(dtom(ip));
                return;
        }
                m_freem(dtom(ip));
                return;
        }
@@ -622,7 +647,7 @@ ip_forward(ip)
        ip_stripoptions(ip, mopt);
 
        /* last 0 here means no directed broadcast */
        ip_stripoptions(ip, mopt);
 
        /* last 0 here means no directed broadcast */
-       if ((error = ip_output(dtom(ip), mopt, 0, 0)) == 0) {
+       if ((error = ip_output(dtom(ip), mopt, (struct route *)0, 0)) == 0) {
                if (mcopy)
                        m_freem(mcopy);
                return;
                if (mcopy)
                        m_freem(mcopy);
                return;