don't allow process to attach to itself
[unix-history] / usr / src / sys / kern / uipc_socket.c
index 733be43..1fe0828 100644 (file)
@@ -4,21 +4,21 @@
  *
  * %sccs.include.redist.c%
  *
  *
  * %sccs.include.redist.c%
  *
- *     @(#)uipc_socket.c       7.25 (Berkeley) %G%
+ *     @(#)uipc_socket.c       7.37 (Berkeley) %G%
  */
 
  */
 
-#include "param.h"
-#include "user.h"
-#include "proc.h"
-#include "file.h"
-#include "malloc.h"
-#include "mbuf.h"
-#include "domain.h"
-#include "kernel.h"
-#include "protosw.h"
-#include "socket.h"
-#include "socketvar.h"
-#include "time.h"
+#include <sys/param.h>
+#include <sys/systm.h>
+#include <sys/proc.h>
+#include <sys/file.h>
+#include <sys/malloc.h>
+#include <sys/mbuf.h>
+#include <sys/domain.h>
+#include <sys/kernel.h>
+#include <sys/protosw.h>
+#include <sys/socket.h>
+#include <sys/socketvar.h>
+#include <sys/resourcevar.h>
 
 /*
  * Socket operation routines.
 
 /*
  * Socket operation routines.
  * sys_socket.c or from a system process, and
  * implement the semantics of socket operations by
  * switching out to the protocol specific routines.
  * sys_socket.c or from a system process, and
  * implement the semantics of socket operations by
  * switching out to the protocol specific routines.
- *
- * TODO:
- *     test socketpair
- *     clean up async
- *     out-of-band is a kludge
  */
 /*ARGSUSED*/
 socreate(dom, aso, type, proto)
  */
 /*ARGSUSED*/
 socreate(dom, aso, type, proto)
+       int dom;
        struct socket **aso;
        register int type;
        int proto;
 {
        struct socket **aso;
        register int type;
        int proto;
 {
+       struct proc *p = curproc;               /* XXX */
        register struct protosw *prp;
        register struct socket *so;
        register int error;
        register struct protosw *prp;
        register struct socket *so;
        register int error;
@@ -53,7 +50,7 @@ socreate(dom, aso, type, proto)
        MALLOC(so, struct socket *, sizeof(*so), M_SOCKET, M_WAIT);
        bzero((caddr_t)so, sizeof(*so));
        so->so_type = type;
        MALLOC(so, struct socket *, sizeof(*so), M_SOCKET, M_WAIT);
        bzero((caddr_t)so, sizeof(*so));
        so->so_type = type;
-       if (u.u_uid == 0)
+       if (p->p_ucred->cr_uid == 0)
                so->so_state = SS_PRIV;
        so->so_proto = prp;
        error =
                so->so_state = SS_PRIV;
        so->so_proto = prp;
        error =
@@ -261,6 +258,7 @@ bad:
        return (error);
 }
 
        return (error);
 }
 
+#define        SBLOCKWAIT(f)   (((f) & MSG_DONTWAIT) ? M_NOWAIT : M_WAITOK)
 /*
  * Send on a socket.
  * If send must go all at once and message is larger than
 /*
  * Send on a socket.
  * If send must go all at once and message is larger than
@@ -286,6 +284,7 @@ sosend(so, addr, uio, top, control, flags)
        struct mbuf *control;
        int flags;
 {
        struct mbuf *control;
        int flags;
 {
+       struct proc *p = curproc;               /* XXX */
        struct mbuf **mp;
        register struct mbuf *m;
        register long space, len, resid;
        struct mbuf **mp;
        register struct mbuf *m;
        register long space, len, resid;
@@ -299,13 +298,13 @@ sosend(so, addr, uio, top, control, flags)
        dontroute =
            (flags & MSG_DONTROUTE) && (so->so_options & SO_DONTROUTE) == 0 &&
            (so->so_proto->pr_flags & PR_ATOMIC);
        dontroute =
            (flags & MSG_DONTROUTE) && (so->so_options & SO_DONTROUTE) == 0 &&
            (so->so_proto->pr_flags & PR_ATOMIC);
-       u.u_ru.ru_msgsnd++;
+       p->p_stats->p_ru.ru_msgsnd++;
        if (control)
                clen = control->m_len;
 #define        snderr(errno)   { error = errno; splx(s); goto release; }
 
 restart:
        if (control)
                clen = control->m_len;
 #define        snderr(errno)   { error = errno; splx(s); goto release; }
 
 restart:
-       if (error = sblock(&so->so_snd))
+       if (error = sblock(&so->so_snd, SBLOCKWAIT(flags)))
                goto out;
        do {
                s = splnet();
                goto out;
        do {
                s = splnet();
@@ -315,7 +314,8 @@ restart:
                        snderr(so->so_error);
                if ((so->so_state & SS_ISCONNECTED) == 0) {
                        if (so->so_proto->pr_flags & PR_CONNREQUIRED) {
                        snderr(so->so_error);
                if ((so->so_state & SS_ISCONNECTED) == 0) {
                        if (so->so_proto->pr_flags & PR_CONNREQUIRED) {
-                               if ((so->so_state & SS_ISCONFIRMING) == 0)
+                               if ((so->so_state & SS_ISCONFIRMING) == 0 &&
+                                   !(resid == 0 && clen != 0))
                                        snderr(ENOTCONN);
                        } else if (addr == 0)
                                snderr(EDESTADDRREQ);
                                        snderr(ENOTCONN);
                        } else if (addr == 0)
                                snderr(EDESTADDRREQ);
@@ -323,11 +323,11 @@ restart:
                space = sbspace(&so->so_snd);
                if (flags & MSG_OOB)
                        space += 1024;
                space = sbspace(&so->so_snd);
                if (flags & MSG_OOB)
                        space += 1024;
-               if (space < resid + clen &&
+               if (atomic && resid > so->so_snd.sb_hiwat ||
+                   clen > so->so_snd.sb_hiwat)
+                       snderr(EMSGSIZE);
+               if (space < resid + clen && uio &&
                    (atomic || space < so->so_snd.sb_lowat || space < clen)) {
                    (atomic || space < so->so_snd.sb_lowat || space < clen)) {
-                       if (atomic && resid > so->so_snd.sb_hiwat ||
-                           clen > so->so_snd.sb_hiwat)
-                               snderr(EMSGSIZE);
                        if (so->so_state & SS_NBIO)
                                snderr(EWOULDBLOCK);
                        sbunlock(&so->so_snd);
                        if (so->so_state & SS_NBIO)
                                snderr(EWOULDBLOCK);
                        sbunlock(&so->so_snd);
@@ -488,7 +488,7 @@ bad:
                    (struct mbuf *)0, (struct mbuf *)0);
 
 restart:
                    (struct mbuf *)0, (struct mbuf *)0);
 
 restart:
-       if (error = sblock(&so->so_rcv))
+       if (error = sblock(&so->so_rcv, SBLOCKWAIT(flags)))
                return (error);
        s = splnet();
 
                return (error);
        s = splnet();
 
@@ -499,24 +499,36 @@ restart:
         *   1. the current count is less than the low water mark, or
         *   2. MSG_WAITALL is set, and it is possible to do the entire
         *      receive operation at once if we block (resid <= hiwat).
         *   1. the current count is less than the low water mark, or
         *   2. MSG_WAITALL is set, and it is possible to do the entire
         *      receive operation at once if we block (resid <= hiwat).
+        *   3. MSG_DONTWAIT is not set
         * If MSG_WAITALL is set but resid is larger than the receive buffer,
         * we have to do the receive in sections, and thus risk returning
         * a short count if a timeout or signal occurs after we start.
         */
         * If MSG_WAITALL is set but resid is larger than the receive buffer,
         * we have to do the receive in sections, and thus risk returning
         * a short count if a timeout or signal occurs after we start.
         */
-       if (m == 0 || so->so_rcv.sb_cc < uio->uio_resid &&
+       if (m == 0 || ((flags & MSG_DONTWAIT) == 0 &&
+           so->so_rcv.sb_cc < uio->uio_resid) &&
            (so->so_rcv.sb_cc < so->so_rcv.sb_lowat ||
            (so->so_rcv.sb_cc < so->so_rcv.sb_lowat ||
-           ((flags & MSG_WAITALL) && uio->uio_resid <= so->so_rcv.sb_hiwat))) {
+           ((flags & MSG_WAITALL) && uio->uio_resid <= so->so_rcv.sb_hiwat)))
+               if (m && (m->m_nextpkt || (m->m_flags & M_EOR) ||
+                         m->m_type == MT_OOBDATA || m->m_type == MT_CONTROL))
+                       break;
 #ifdef DIAGNOSTIC
                if (m == 0 && so->so_rcv.sb_cc)
                        panic("receive 1");
 #endif
                if (so->so_error) {
 #ifdef DIAGNOSTIC
                if (m == 0 && so->so_rcv.sb_cc)
                        panic("receive 1");
 #endif
                if (so->so_error) {
+                       if (m)
+                               goto dontblock;
                        error = so->so_error;
                        error = so->so_error;
-                       so->so_error = 0;
+                       if ((flags & MSG_PEEK) == 0)
+                               so->so_error = 0;
                        goto release;
                }
                        goto release;
                }
-               if (so->so_state & SS_CANTRCVMORE)
-                       goto release;
+               if (so->so_state & SS_CANTRCVMORE) {
+                       if (m)
+                               goto dontblock;
+                       else
+                               goto release;
+               }
                if ((so->so_state & (SS_ISCONNECTED|SS_ISCONNECTING)) == 0 &&
                    (so->so_proto->pr_flags & PR_CONNREQUIRED)) {
                        error = ENOTCONN;
                if ((so->so_state & (SS_ISCONNECTED|SS_ISCONNECTING)) == 0 &&
                    (so->so_proto->pr_flags & PR_CONNREQUIRED)) {
                        error = ENOTCONN;
@@ -524,7 +536,7 @@ restart:
                }
                if (uio->uio_resid == 0)
                        goto release;
                }
                if (uio->uio_resid == 0)
                        goto release;
-               if (so->so_state & SS_NBIO) {
+               if ((so->so_state & SS_NBIO) || (flags & MSG_DONTWAIT)) {
                        error = EWOULDBLOCK;
                        goto release;
                }
                        error = EWOULDBLOCK;
                        goto release;
                }
@@ -535,8 +547,10 @@ restart:
                        return (error);
                goto restart;
        }
                        return (error);
                goto restart;
        }
-       u.u_ru.ru_msgrcv++;
+       if (uio->uio_procp)
+               uio->uio_procp->p_stats->p_ru.ru_msgrcv++;
        nextrecord = m->m_nextpkt;
        nextrecord = m->m_nextpkt;
+       record_eor = m->m_flags & M_EOR;
        if (pr->pr_flags & PR_ADDR) {
 #ifdef DIAGNOSTIC
                if (m->m_type != MT_SONAME)
        if (pr->pr_flags & PR_ADDR) {
 #ifdef DIAGNOSTIC
                if (m->m_type != MT_SONAME)
@@ -623,8 +637,6 @@ restart:
                } else
                        uio->uio_resid -= len;
                if (len == m->m_len - moff) {
                } else
                        uio->uio_resid -= len;
                if (len == m->m_len - moff) {
-                       if (m->m_flags & M_EOR)
-                               flags |= MSG_EOR;
                        if (flags & MSG_PEEK) {
                                m = m->m_next;
                                moff = 0;
                        if (flags & MSG_PEEK) {
                                m = m->m_next;
                                moff = 0;
@@ -664,8 +676,10 @@ restart:
                        } else
                                offset += len;
                }
                        } else
                                offset += len;
                }
-               if (flags & MSG_EOR)
+               if (m == 0 && record_eor) {
+                       flags |= record_eor;
                        break;
                        break;
+               }
                /*
                 * If the MSG_WAITALL flag is set (for non-atomic socket),
                 * we must not quit until "uio->uio_resid == 0" or an error
                /*
                 * If the MSG_WAITALL flag is set (for non-atomic socket),
                 * we must not quit until "uio->uio_resid == 0" or an error
@@ -674,18 +688,19 @@ restart:
                 * Keep sockbuf locked against other readers.
                 */
                while (flags & MSG_WAITALL && m == 0 && uio->uio_resid > 0 &&
                 * Keep sockbuf locked against other readers.
                 */
                while (flags & MSG_WAITALL && m == 0 && uio->uio_resid > 0 &&
-                   !sosendallatonce(so)) {
+                  !(flags & MSG_OOB) && !sosendallatonce(so)) {
+                       if (so->so_error || so->so_state & SS_CANTRCVMORE)
+                               break;
                        error = sbwait(&so->so_rcv);
                        if (error) {
                                sbunlock(&so->so_rcv);
                                splx(s);
                                return (0);
                        }
                        error = sbwait(&so->so_rcv);
                        if (error) {
                                sbunlock(&so->so_rcv);
                                splx(s);
                                return (0);
                        }
-                       if (m = so->so_rcv.sb_mb)
+                       if (m = so->so_rcv.sb_mb) {
                                nextrecord = m->m_nextpkt;
                                nextrecord = m->m_nextpkt;
-                       if (so->so_error || so->so_state & SS_CANTRCVMORE)
-                               break;
-                       continue;
+                               record_eor |= m->m_flags & M_EOR;
+                       }
                }
        }
        if ((flags & MSG_PEEK) == 0) {
                }
        }
        if ((flags & MSG_PEEK) == 0) {
@@ -732,7 +747,7 @@ sorflush(so)
        struct sockbuf asb;
 
        sb->sb_flags |= SB_NOINTR;
        struct sockbuf asb;
 
        sb->sb_flags |= SB_NOINTR;
-       (void) sblock(sb);
+       (void) sblock(sb, M_WAITOK);
        s = splimp();
        socantrcvmore(so);
        sbunlock(sb);
        s = splimp();
        socantrcvmore(so);
        sbunlock(sb);
@@ -774,6 +789,7 @@ sosetopt(so, level, optname, m0)
                case SO_USELOOPBACK:
                case SO_BROADCAST:
                case SO_REUSEADDR:
                case SO_USELOOPBACK:
                case SO_BROADCAST:
                case SO_REUSEADDR:
+               case SO_REUSEPORT:
                case SO_OOBINLINE:
                        if (m == NULL || m->m_len < sizeof (int)) {
                                error = EINVAL;
                case SO_OOBINLINE:
                        if (m == NULL || m->m_len < sizeof (int)) {
                                error = EINVAL;
@@ -847,6 +863,10 @@ sosetopt(so, level, optname, m0)
                        error = ENOPROTOOPT;
                        break;
                }
                        error = ENOPROTOOPT;
                        break;
                }
+               m = 0;
+               if (error == 0 && so->so_proto && so->so_proto->pr_ctloutput)
+                       (void) ((*so->so_proto->pr_ctloutput)
+                                 (PRCO_SETOPT, so, level, optname, &m0));
        }
 bad:
        if (m)
        }
 bad:
        if (m)
@@ -885,6 +905,7 @@ sogetopt(so, level, optname, mp)
                case SO_DEBUG:
                case SO_KEEPALIVE:
                case SO_REUSEADDR:
                case SO_DEBUG:
                case SO_KEEPALIVE:
                case SO_REUSEADDR:
+               case SO_REUSEPORT:
                case SO_BROADCAST:
                case SO_OOBINLINE:
                        *mtod(m, int *) = so->so_options & optname;
                case SO_BROADCAST:
                case SO_OOBINLINE:
                        *mtod(m, int *) = so->so_options & optname;
@@ -946,9 +967,5 @@ sohasoutofband(so)
                gsignal(-so->so_pgid, SIGURG);
        else if (so->so_pgid > 0 && (p = pfind(so->so_pgid)) != 0)
                psignal(p, SIGURG);
                gsignal(-so->so_pgid, SIGURG);
        else if (so->so_pgid > 0 && (p = pfind(so->so_pgid)) != 0)
                psignal(p, SIGURG);
-       if (so->so_rcv.sb_sel) {
-               selwakeup(so->so_rcv.sb_sel, so->so_rcv.sb_flags & SB_COLL);
-               so->so_rcv.sb_sel = 0;
-               so->so_rcv.sb_flags &= ~SB_COLL;
-       }
+       selwakeup(&so->so_rcv.sb_sel);
 }
 }