new template
[unix-history] / usr / src / libexec / tftpd / tftpd.c
index da3d85e..9a9529e 100644 (file)
@@ -11,12 +11,16 @@ char copyright[] =
 #endif not lint
 
 #ifndef lint
 #endif not lint
 
 #ifndef lint
-static char sccsid[] = "@(#)tftpd.c    5.1 (Berkeley) %G%";
+static char sccsid[] = "@(#)tftpd.c    5.6 (Berkeley) %G%";
 #endif not lint
 
 #endif not lint
 
+
 /*
  * Trivial file transfer protocol server.
 /*
  * Trivial file transfer protocol server.
+ *
+ * This version includes many modifications by Jim Guyton <guyton@rand-unix>
  */
  */
+
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/ioctl.h>
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/ioctl.h>
@@ -42,7 +46,10 @@ struct       sockaddr_in sin = { AF_INET };
 int    peer;
 int    rexmtval = TIMEOUT;
 int    maxtimeout = 5*TIMEOUT;
 int    peer;
 int    rexmtval = TIMEOUT;
 int    maxtimeout = 5*TIMEOUT;
-char   buf[BUFSIZ];
+
+#define        PKTSIZE SEGSIZE+4
+char   buf[PKTSIZE];
+char   ackbuf[PKTSIZE];
 struct sockaddr_in from;
 int    fromlen;
 
 struct sockaddr_in from;
 int    fromlen;
 
@@ -50,33 +57,85 @@ main()
 {
        register struct tftphdr *tp;
        register int n;
 {
        register struct tftphdr *tp;
        register int n;
+       int on = 1;
 
 
-       alarm(10);
+       openlog("tftpd", LOG_PID, LOG_DAEMON);
+       if (ioctl(0, FIONBIO, &on) < 0) {
+               syslog(LOG_ERR, "ioctl(FIONBIO): %m\n");
+               exit(1);
+       }
        fromlen = sizeof (from);
        n = recvfrom(0, buf, sizeof (buf), 0,
            (caddr_t)&from, &fromlen);
        if (n < 0) {
        fromlen = sizeof (from);
        n = recvfrom(0, buf, sizeof (buf), 0,
            (caddr_t)&from, &fromlen);
        if (n < 0) {
-               perror("tftpd: recvfrom");
+               syslog(LOG_ERR, "recvfrom: %m\n");
                exit(1);
        }
                exit(1);
        }
+       /*
+        * Now that we have read the message out of the UDP
+        * socket, we fork and exit.  Thus, inetd will go back
+        * to listening to the tftp port, and the next request
+        * to come in will start up a new instance of tftpd.
+        *
+        * We do this so that inetd can run tftpd in "wait" mode.
+        * The problem with tftpd running in "nowait" mode is that
+        * inetd may get one or more successful "selects" on the
+        * tftp port before we do our receive, so more than one
+        * instance of tftpd may be started up.  Worse, if tftpd
+        * break before doing the above "recvfrom", inetd would
+        * spawn endless instances, clogging the system.
+        */
+       {
+               int pid;
+               int i, j;
+
+               for (i = 1; i < 20; i++) {
+                   pid = fork();
+                   if (pid < 0) {
+                               sleep(i);
+                               /*
+                                * flush out to most recently sent request.
+                                *
+                                * This may drop some request, but those
+                                * will be resent by the clients when
+                                * they timeout.  The positive effect of
+                                * this flush is to (try to) prevent more
+                                * than one tftpd being started up to service
+                                * a single request from a single client.
+                                */
+                               j = sizeof from;
+                               i = recvfrom(0, buf, sizeof (buf), 0,
+                                   (caddr_t)&from, &j);
+                               if (i > 0) {
+                                       n = i;
+                                       fromlen = j;
+                               }
+                   } else {
+                               break;
+                   }
+               }
+               if (pid < 0) {
+                       syslog(LOG_ERR, "fork: %m\n");
+                       exit(1);
+               } else if (pid != 0) {
+                       exit(0);
+               }
+       }
        from.sin_family = AF_INET;
        alarm(0);
        close(0);
        close(1);
        peer = socket(AF_INET, SOCK_DGRAM, 0);
        if (peer < 0) {
        from.sin_family = AF_INET;
        alarm(0);
        close(0);
        close(1);
        peer = socket(AF_INET, SOCK_DGRAM, 0);
        if (peer < 0) {
-               openlog("tftpd", LOG_PID, 0);
-               syslog(LOG_ERR, "socket: %m");
+               syslog(LOG_ERR, "socket: %m\n");
                exit(1);
        }
        if (bind(peer, (caddr_t)&sin, sizeof (sin)) < 0) {
                exit(1);
        }
        if (bind(peer, (caddr_t)&sin, sizeof (sin)) < 0) {
-               openlog("tftpd", LOG_PID, 0);
-               syslog(LOG_ERR, "bind: %m");
+               syslog(LOG_ERR, "bind: %m\n");
                exit(1);
        }
        if (connect(peer, (caddr_t)&from, sizeof(from)) < 0) {
                exit(1);
        }
        if (connect(peer, (caddr_t)&from, sizeof(from)) < 0) {
-               openlog("tftpd", LOG_PID, 0);
-               syslog(LOG_ERR, "connect: %m");
+               syslog(LOG_ERR, "connect: %m\n");
                exit(1);
        }
        tp = (struct tftphdr *)buf;
                exit(1);
        }
        tp = (struct tftphdr *)buf;
@@ -94,11 +153,12 @@ struct formats {
        int     (*f_validate)();
        int     (*f_send)();
        int     (*f_recv)();
        int     (*f_validate)();
        int     (*f_send)();
        int     (*f_recv)();
+       int     f_convert;
 } formats[] = {
 } formats[] = {
-       { "netascii",   validate_access,        sendfile,       recvfile },
-       { "octet",      validate_access,        sendfile,       recvfile },
+       { "netascii",   validate_access,        sendfile,       recvfile, 1 },
+       { "octet",      validate_access,        sendfile,       recvfile, 0 },
 #ifdef notdef
 #ifdef notdef
-       { "mail",       validate_user,          sendmail,       recvmail },
+       { "mail",       validate_user,          sendmail,       recvmail, 1 },
 #endif
        { 0 }
 };
 #endif
        { 0 }
 };
@@ -153,7 +213,8 @@ again:
        exit(0);
 }
 
        exit(0);
 }
 
-int    fd;
+
+FILE *file;
 
 /*
  * Validate file access.  Since we
 
 /*
  * Validate file access.  Since we
@@ -163,15 +224,16 @@ int       fd;
  * Note also, full path name must be
  * given as we have no login directory.
  */
  * Note also, full path name must be
  * given as we have no login directory.
  */
-validate_access(file, mode)
-       char *file;
+validate_access(filename, mode)
+       char *filename;
        int mode;
 {
        struct stat stbuf;
        int mode;
 {
        struct stat stbuf;
+       int     fd;
 
 
-       if (*file != '/')
+       if (*filename != '/')
                return (EACCESS);
                return (EACCESS);
-       if (stat(file, &stbuf) < 0)
+       if (stat(filename, &stbuf) < 0)
                return (errno == ENOENT ? ENOTFOUND : EACCESS);
        if (mode == RRQ) {
                if ((stbuf.st_mode&(S_IREAD >> 6)) == 0)
                return (errno == ENOENT ? ENOTFOUND : EACCESS);
        if (mode == RRQ) {
                if ((stbuf.st_mode&(S_IREAD >> 6)) == 0)
@@ -180,9 +242,13 @@ validate_access(file, mode)
                if ((stbuf.st_mode&(S_IWRITE >> 6)) == 0)
                        return (EACCESS);
        }
                if ((stbuf.st_mode&(S_IWRITE >> 6)) == 0)
                        return (EACCESS);
        }
-       fd = open(file, mode == RRQ ? 0 : 1);
+       fd = open(filename, mode == RRQ ? 0 : 1);
        if (fd < 0)
                return (errno + 100);
        if (fd < 0)
                return (errno + 100);
+       file = fdopen(fd, (mode == RRQ)? "r":"w");
+       if (file == NULL) {
+               return errno+100;
+       }
        return (0);
 }
 
        return (0);
 }
 
@@ -202,88 +268,143 @@ timer()
  * Send the requested file.
  */
 sendfile(pf)
  * Send the requested file.
  */
 sendfile(pf)
-       struct format *pf;
+       struct formats *pf;
 {
 {
-       register struct tftphdr *tp;
+       struct tftphdr *dp, *r_init();
+       register struct tftphdr *ap;    /* ack packet */
        register int block = 1, size, n;
 
        signal(SIGALRM, timer);
        register int block = 1, size, n;
 
        signal(SIGALRM, timer);
-       tp = (struct tftphdr *)buf;
+       dp = r_init();
+       ap = (struct tftphdr *)ackbuf;
        do {
        do {
-               size = read(fd, tp->th_data, SEGSIZE);
+               size = readit(file, &dp, pf->f_convert);
                if (size < 0) {
                        nak(errno + 100);
                if (size < 0) {
                        nak(errno + 100);
-                       return;
+                       goto abort;
                }
                }
-               tp->th_opcode = htons((u_short)DATA);
-               tp->th_block = htons((u_short)block);
+               dp->th_opcode = htons((u_short)DATA);
+               dp->th_block = htons((u_short)block);
                timeout = 0;
                (void) setjmp(timeoutbuf);
                timeout = 0;
                (void) setjmp(timeoutbuf);
-               if (send(peer, buf, size + 4, 0) != size + 4) {
-                       perror("tftpd: send");
-                       return;
+
+send_data:
+               if (send(peer, dp, size + 4, 0) != size + 4) {
+                       syslog(LOG_ERR, "tftpd: write: %m\n");
+                       goto abort;
                }
                }
-               do {
-                       alarm(rexmtval);
-                       n = recv(peer, buf, sizeof (buf), 0);
+               read_ahead(file, pf->f_convert);
+               for ( ; ; ) {
+                       alarm(rexmtval);        /* read the ack */
+                       n = recv(peer, ackbuf, sizeof (ackbuf), 0);
                        alarm(0);
                        if (n < 0) {
                        alarm(0);
                        if (n < 0) {
-                               perror("tftpd: recv");
-                               return;
+                               syslog(LOG_ERR, "tftpd: read: %m\n");
+                               goto abort;
+                       }
+                       ap->th_opcode = ntohs((u_short)ap->th_opcode);
+                       ap->th_block = ntohs((u_short)ap->th_block);
+
+                       if (ap->th_opcode == ERROR)
+                               goto abort;
+                       
+                       if (ap->th_opcode == ACK) {
+                               if (ap->th_block == block) {
+                                       break;
+                               }
+                               /* Re-synchronize with the other side */
+                               (void) synchnet(peer);
+                               if (ap->th_block == (block -1)) {
+                                       goto send_data;
+                               }
                        }
                        }
-                       tp->th_opcode = ntohs((u_short)tp->th_opcode);
-                       tp->th_block = ntohs((u_short)tp->th_block);
-                       if (tp->th_opcode == ERROR)
-                               return;
-               } while (tp->th_opcode != ACK || tp->th_block != block);
+
+               }
                block++;
        } while (size == SEGSIZE);
                block++;
        } while (size == SEGSIZE);
+abort:
+       (void) fclose(file);
 }
 
 }
 
+justquit()
+{
+       exit(0);
+}
+
+
 /*
  * Receive a file.
  */
 recvfile(pf)
 /*
  * Receive a file.
  */
 recvfile(pf)
-       struct format *pf;
+       struct formats *pf;
 {
 {
-       register struct tftphdr *tp;
+       struct tftphdr *dp, *w_init();
+       register struct tftphdr *ap;    /* ack buffer */
        register int block = 0, n, size;
 
        signal(SIGALRM, timer);
        register int block = 0, n, size;
 
        signal(SIGALRM, timer);
-       tp = (struct tftphdr *)buf;
+       dp = w_init();
+       ap = (struct tftphdr *)ackbuf;
        do {
                timeout = 0;
        do {
                timeout = 0;
-               tp->th_opcode = htons((u_short)ACK);
-               tp->th_block = htons((u_short)block);
+               ap->th_opcode = htons((u_short)ACK);
+               ap->th_block = htons((u_short)block);
                block++;
                (void) setjmp(timeoutbuf);
                block++;
                (void) setjmp(timeoutbuf);
-               if (send(peer, buf, 4, 0) != 4) {
-                       perror("tftpd: send");
+send_ack:
+               if (send(peer, ackbuf, 4, 0) != 4) {
+                       syslog(LOG_ERR, "tftpd: write: %m\n");
                        goto abort;
                }
                        goto abort;
                }
-               do {
+               write_behind(file, pf->f_convert);
+               for ( ; ; ) {
                        alarm(rexmtval);
                        alarm(rexmtval);
-                       n = recv(peer, buf, sizeof (buf), 0);
+                       n = recv(peer, dp, PKTSIZE, 0);
                        alarm(0);
                        alarm(0);
-                       if (n < 0) {
-                               perror("tftpd: recv");
+                       if (n < 0) {            /* really? */
+                               syslog(LOG_ERR, "tftpd: read: %m\n");
                                goto abort;
                        }
                                goto abort;
                        }
-                       tp->th_opcode = ntohs((u_short)tp->th_opcode);
-                       tp->th_block = ntohs((u_short)tp->th_block);
-                       if (tp->th_opcode == ERROR)
+                       dp->th_opcode = ntohs((u_short)dp->th_opcode);
+                       dp->th_block = ntohs((u_short)dp->th_block);
+                       if (dp->th_opcode == ERROR)
                                goto abort;
                                goto abort;
-               } while (tp->th_opcode != DATA || block != tp->th_block);
-               size = write(fd, tp->th_data, n - 4);
-               if (size < 0) {
-                       nak(errno + 100);
+                       if (dp->th_opcode == DATA) {
+                               if (dp->th_block == block) {
+                                       break;   /* normal */
+                               }
+                               /* Re-synchronize with the other side */
+                               (void) synchnet(peer);
+                               if (dp->th_block == (block-1))
+                                       goto send_ack;          /* rexmit */
+                       }
+               }
+               /*  size = write(file, dp->th_data, n - 4); */
+               size = writeit(file, &dp, n - 4, pf->f_convert);
+               if (size != (n-4)) {                    /* ahem */
+                       if (size < 0) nak(errno + 100);
+                       else nak(ENOSPACE);
                        goto abort;
                }
        } while (size == SEGSIZE);
                        goto abort;
                }
        } while (size == SEGSIZE);
+       write_behind(file, pf->f_convert);
+       (void) fclose(file);            /* close data file */
+
+       ap->th_opcode = htons((u_short)ACK);    /* send the "final" ack */
+       ap->th_block = htons((u_short)(block));
+       (void) send(peer, ackbuf, 4, 0);
+
+       signal(SIGALRM, justquit);      /* just quit on timeout */
+       alarm(rexmtval);
+       n = recv(peer, buf, sizeof (buf), 0); /* normally times out and quits */
+       alarm(0);
+       if (n >= 4 &&                   /* if read some data */
+           dp->th_opcode == DATA &&    /* and got a data block */
+           block == dp->th_block) {    /* then my last ack was lost */
+               (void) send(peer, ackbuf, 4, 0);     /* resend final ack */
+       }
 abort:
 abort:
-       tp->th_opcode = htons((u_short)ACK);
-       tp->th_block = htons((u_short)(block));
-       (void) send(peer, buf, 4, 0);
+       return;
 }
 
 struct errmsg {
 }
 
 struct errmsg {
@@ -321,13 +442,14 @@ nak(error)
        for (pe = errmsgs; pe->e_code >= 0; pe++)
                if (pe->e_code == error)
                        break;
        for (pe = errmsgs; pe->e_code >= 0; pe++)
                if (pe->e_code == error)
                        break;
-       if (pe->e_code < 0)
+       if (pe->e_code < 0) {
                pe->e_msg = sys_errlist[error - 100];
                pe->e_msg = sys_errlist[error - 100];
+               tp->th_code = EUNDEF;   /* set 'undef' errorcode */
+       }
        strcpy(tp->th_msg, pe->e_msg);
        length = strlen(pe->e_msg);
        tp->th_msg[length] = '\0';
        length += 5;
        if (send(peer, buf, length, 0) != length)
        strcpy(tp->th_msg, pe->e_msg);
        length = strlen(pe->e_msg);
        tp->th_msg[length] = '\0';
        length += 5;
        if (send(peer, buf, length, 0) != length)
-               perror("nak");
-       exit(1);
+               syslog(LOG_ERR, "nak: %m\n");
 }
 }