diff options
Diffstat (limited to 'fs/smbfs/sock.c')
-rw-r--r-- | fs/smbfs/sock.c | 114 |
1 files changed, 73 insertions, 41 deletions
diff --git a/fs/smbfs/sock.c b/fs/smbfs/sock.c index ca6d8c269..4d85b8e66 100644 --- a/fs/smbfs/sock.c +++ b/fs/smbfs/sock.c @@ -15,6 +15,7 @@ #include <linux/net.h> #include <linux/mm.h> #include <linux/netdevice.h> +#include <net/scm.h> #include <net/ip.h> #include <linux/smb.h> @@ -26,42 +27,53 @@ static int _recvfrom(struct socket *sock, unsigned char *ubuf, int size, - int noblock, unsigned flags, struct sockaddr_in *sa, int *addr_len) + unsigned flags) { struct iovec iov; struct msghdr msg; + struct scm_cookie scm; - iov.iov_base = ubuf; - iov.iov_len = size; - - msg.msg_name = (void *) sa; + msg.msg_name = NULL; msg.msg_namelen = 0; - if (addr_len) - msg.msg_namelen = *addr_len; - msg.msg_control = NULL; msg.msg_iov = &iov; msg.msg_iovlen = 1; - - return sock->ops->recvmsg(sock, &msg, size, noblock, flags, addr_len); + msg.msg_control = NULL; + iov.iov_base = ubuf; + iov.iov_len = size; + + memset(&scm, 0,sizeof(scm)); + size=sock->ops->recvmsg(sock, &msg, size, flags, &scm); + if(size>=0) + scm_recv(sock,&msg,&scm,flags); + return size; } static int -_send(struct socket *sock, const void *buff, int len, - int nonblock, unsigned flags) +_send(struct socket *sock, const void *buff, int len) { struct iovec iov; struct msghdr msg; - - iov.iov_base = (void *) buff; - iov.iov_len = len; + struct scm_cookie scm; + int err; msg.msg_name = NULL; msg.msg_namelen = 0; - msg.msg_control = NULL; msg.msg_iov = &iov; msg.msg_iovlen = 1; + msg.msg_control = NULL; + msg.msg_controllen = 0; + + iov.iov_base = (void *)buff; + iov.iov_len = len; + + msg.msg_flags = 0; - return sock->ops->sendmsg(sock, &msg, len, nonblock, flags); + err = scm_send(sock, &msg, &scm); + if (err < 0) + return err; + err = sock->ops->sendmsg(sock, &msg, len, &scm); + scm_destroy(&scm); + return err; } static void @@ -78,14 +90,14 @@ smb_data_callback(struct sock *sk, int len) fs = get_fs(); set_fs(get_ds()); - result = _recvfrom(sock, (void *) peek_buf, 1, 1, - MSG_PEEK, NULL, NULL); + result = _recvfrom(sock, (void *) peek_buf, 1, + MSG_PEEK | MSG_DONTWAIT); while ((result != -EAGAIN) && (peek_buf[0] == 0x85)) { /* got SESSION KEEP ALIVE */ - result = _recvfrom(sock, (void *) peek_buf, - 4, 1, 0, NULL, NULL); + result = _recvfrom(sock, (void *) peek_buf, 4, + MSG_DONTWAIT); DDPRINTK("smb_data_callback:" " got SESSION KEEP ALIVE\n"); @@ -94,9 +106,8 @@ smb_data_callback(struct sock *sk, int len) { break; } - result = _recvfrom(sock, (void *) peek_buf, - 1, 1, MSG_PEEK, - NULL, NULL); + result = _recvfrom(sock, (void *) peek_buf, 1, + MSG_PEEK | MSG_DONTWAIT); } set_fs(fs); @@ -132,7 +143,7 @@ smb_catch_keepalive(struct smb_server *server) server->data_ready = NULL; return -EINVAL; } - sk = (struct sock *) (sock->data); + sk = sock->sk; if (sk == NULL) { @@ -178,7 +189,7 @@ smb_dont_catch_keepalive(struct smb_server *server) printk("smb_dont_catch_keepalive: did not get SOCK_STREAM\n"); return -EINVAL; } - sk = (struct sock *) (sock->data); + sk = sock->sk; if (sk == NULL) { @@ -216,8 +227,12 @@ smb_send_raw(struct socket *sock, unsigned char *source, int length) { result = _send(sock, (void *) (source + already_sent), - length - already_sent, 0, 0); + length - already_sent); + if (result == 0) + { + return -EIO; + } if (result < 0) { DPRINTK("smb_send_raw: sendto error = %d\n", @@ -239,9 +254,12 @@ smb_receive_raw(struct socket *sock, unsigned char *target, int length) { result = _recvfrom(sock, (void *) (target + already_read), - length - already_read, 0, 0, - NULL, NULL); + length - already_read, 0); + if (result == 0) + { + return -EIO; + } if (result < 0) { DPRINTK("smb_receive_raw: recvfrom error = %d\n", @@ -369,7 +387,6 @@ smb_receive_trans2(struct smb_server *server, int total_data = 0; int total_param = 0; int result; - unsigned char *inbuf = server->packet; unsigned char *rcv_buf; int buf_len; int data_len = 0; @@ -385,8 +402,8 @@ smb_receive_trans2(struct smb_server *server, *ldata = *lparam = 0; return 0; } - total_data = WVAL(inbuf, smb_tdrcnt); - total_param = WVAL(inbuf, smb_tprcnt); + total_data = WVAL(server->packet, smb_tdrcnt); + total_param = WVAL(server->packet, smb_tprcnt); DDPRINTK("smb_receive_trans2: td=%d,tp=%d\n", total_data, total_param); @@ -411,6 +428,8 @@ smb_receive_trans2(struct smb_server *server, while (1) { + unsigned char *inbuf = server->packet; + if (WVAL(inbuf, smb_prdisp) + WVAL(inbuf, smb_prcnt) > total_param) { @@ -480,6 +499,8 @@ smb_receive_trans2(struct smb_server *server, return result; } +extern struct net_proto_family inet_family_ops; + int smb_release(struct smb_server *server) { @@ -498,8 +519,8 @@ smb_release(struct smb_server *server) is nothing behind it, so I set it to SS_UNCONNECTED. */ sock->state = SS_UNCONNECTED; - result = sock->ops->create(sock, 0); - DPRINTK("smb_release: sock->ops->create = %d\n", result); + result = inet_family_ops.create(sock, 0); + DPRINTK("smb_release: inet_create = %d\n", result); return result; } @@ -588,6 +609,8 @@ smb_send_trans2(struct smb_server *server, __u16 trans2_command, int lparam, unsigned char *param) { struct socket *sock = server_sock(server); + struct scm_cookie scm; + int err; /* I know the following is very ugly, but I want to build the smb packet as efficiently as possible. */ @@ -632,6 +655,15 @@ smb_send_trans2(struct smb_server *server, __u16 trans2_command, *p++ = 'D'; /* this was added because OS/2 does it */ *p++ = ' '; + + msg.msg_name = NULL; + msg.msg_namelen = 0; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_iov = iov; + msg.msg_iovlen = 4; + msg.msg_flags = 0; + iov[0].iov_base = (void *) server->packet; iov[0].iov_len = oparam; iov[1].iov_base = (param == NULL) ? padding : param; @@ -641,13 +673,13 @@ smb_send_trans2(struct smb_server *server, __u16 trans2_command, iov[3].iov_base = (data == NULL) ? padding : data; iov[3].iov_len = ldata; - msg.msg_name = NULL; - msg.msg_namelen = 0; - msg.msg_control = NULL; - msg.msg_iov = iov; - msg.msg_iovlen = 4; - - return sock->ops->sendmsg(sock, &msg, packet_length, 0, 0); + err = scm_send(sock, &msg, &scm); + if (err < 0) + return err; + + err = sock->ops->sendmsg(sock, &msg, packet_length, &scm); + scm_destroy(&scm); + return err; } /* |