Check socket creation errors against PGINVALID_SOCKET
author     Bruce Momjian <[email protected]>
Wed, 16 Apr 2014 14:45:48 +0000 (10:45 -0400)
committer  Bruce Momjian <[email protected]>
Wed, 16 Apr 2014 14:45:48 +0000 (10:45 -0400)
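Previously, in some places, socket creation errors (and, in one place,
accept() failures) were detected by checking for negative values.  That
does not work on Windows, because socket handles are unsigned there
(INVALID_SOCKET is ~0, not a negative number).  This masked socket
creation errors on Windows.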

Backpatch through 9.0.  8.4 doesn't have the infrastructure to fix this.

Files changed:
src/backend/libpq/auth.c
src/backend/libpq/ip.c
src/backend/libpq/pqcomm.c
src/backend/port/win32/socket.c
src/backend/postmaster/postmaster.c
src/interfaces/libpq/fe-connect.c
src/interfaces/libpq/libpq-int.h

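diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index f6af892b7ab76110770becd3f9b150e7aaa6fcd2..d2f8e8d4582fa619f02b80da9f20cff739df1802 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c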
@@ -1677,7 +1677,7 @@ ident_inet(hbaPort *port)
 
    sock_fd = socket(ident_serv->ai_family, ident_serv->ai_socktype,
                     ident_serv->ai_protocol);
-   if (sock_fd < 0)
+   if (sock_fd == PGINVALID_SOCKET)
    {
        ereport(LOG,
                (errcode_for_socket_access(),
@@ -1757,7 +1757,7 @@ ident_inet(hbaPort *port)
                    ident_response)));
 
 ident_inet_done:
-   if (sock_fd >= 0)
+   if (sock_fd != PGINVALID_SOCKET)
        closesocket(sock_fd);
    pg_freeaddrinfo_all(remote_addr.addr.ss_family, ident_serv);
    pg_freeaddrinfo_all(local_addr.addr.ss_family, la);
@@ -2580,7 +2580,7 @@ CheckRADIUSAuth(Port *port)
    packet->length = htons(packet->length);
 
    sock = socket(serveraddrs[0].ai_family, SOCK_DGRAM, 0);
-   if (sock < 0)
+   if (sock == PGINVALID_SOCKET)
    {
        ereport(LOG,
                (errmsg("could not create RADIUS socket: %m")));
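diff --git a/src/backend/libpq/ip.c b/src/backend/libpq/ip.c
index 7a0ed4ed4ac0f58087b43e355b508846ff8ca081..6dcaede083aa4d8b3363b99ec66af2eda9af3047 100644
--- a/src/backend/libpq/ip.c
+++ b/src/backend/libpq/ip.c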
@@ -547,7 +547,7 @@ pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
    int         error;
 
    sock = WSASocket(AF_INET, SOCK_DGRAM, 0, 0, 0, 0);
-   if (sock == SOCKET_ERROR)
+   if (sock == INVALID_SOCKET)
        return -1;
 
    while (n_ii < 1024)
@@ -670,7 +670,7 @@ pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
                total;
 
    sock = socket(AF_INET, SOCK_DGRAM, 0);
-   if (sock == -1)
+   if (sock == PGINVALID_SOCKET)
        return -1;
 
    while (n_buffer < 1024 * 100)
@@ -711,7 +711,7 @@ pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
 #ifdef HAVE_IPV6
    /* We'll need an IPv6 socket too for the SIOCGLIFNETMASK ioctls */
    sock6 = socket(AF_INET6, SOCK_DGRAM, 0);
-   if (sock6 == -1)
+   if (sock6 == PGINVALID_SOCKET)
    {
        free(buffer);
        close(sock);
@@ -788,10 +788,10 @@ pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
    char       *ptr,
               *buffer = NULL;
    size_t      n_buffer = 1024;
-   int         sock;
+   pgsocket    sock;
 
    sock = socket(AF_INET, SOCK_DGRAM, 0);
-   if (sock == -1)
+   if (sock == PGINVALID_SOCKET)
        return -1;
 
    while (n_buffer < 1024 * 100)
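diff --git a/src/backend/libpq/pqcomm.c b/src/backend/libpq/pqcomm.c
index c79c846a66006103337d93bd0785dc08ec5a4885..2a1fc145ee6c88ac24c2d9c0105e0d6d97bbddfd 100644
--- a/src/backend/libpq/pqcomm.c
+++ b/src/backend/libpq/pqcomm.c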
@@ -363,7 +363,7 @@ StreamServerPort(int family, char *hostName, unsigned short portNumber,
                break;
        }
 
-       if ((fd = socket(addr->ai_family, SOCK_STREAM, 0)) < 0)
+       if ((fd = socket(addr->ai_family, SOCK_STREAM, 0)) == PGINVALID_SOCKET)
        {
            ereport(LOG,
                    (errcode_for_socket_access(),
@@ -606,7 +606,7 @@ StreamConnection(pgsocket server_fd, Port *port)
    port->raddr.salen = sizeof(port->raddr.addr);
    if ((port->sock = accept(server_fd,
                             (struct sockaddr *) & port->raddr.addr,
-                            &port->raddr.salen)) < 0)
+                            &port->raddr.salen)) == PGINVALID_SOCKET)
    {
        ereport(LOG,
                (errcode_for_socket_access(),
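diff --git a/src/backend/port/win32/socket.c b/src/backend/port/win32/socket.c
index 9e849271fe67b01d814e9033b3f12019d5fa31d7..b75067d9ca67cf81f97f552c943763452ebcb780 100644
--- a/src/backend/port/win32/socket.c
+++ b/src/backend/port/win32/socket.c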
@@ -132,7 +132,7 @@ int
 pgwin32_waitforsinglesocket(SOCKET s, int what, int timeout)
 {
    static HANDLE waitevent = INVALID_HANDLE_VALUE;
-   static SOCKET current_socket = -1;
+   static SOCKET current_socket = INVALID_SOCKET;
    static int  isUDP = 0;
    HANDLE      events[2];
    int         r;
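diff --git a/src/backend/postmaster/postmaster.c b/src/backend/postmaster/postmaster.c
index 29b519813fec72acb2226cfd40cb5d4d9ee02470..1fabbce517fbecdcc8d5f721e22295d48bdbf100 100644
--- a/src/backend/postmaster/postmaster.c
+++ b/src/backend/postmaster/postmaster.c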
@@ -2013,7 +2013,7 @@ ConnCreate(int serverFd)
 
    if (StreamConnection(serverFd, port) != STATUS_OK)
    {
-       if (port->sock >= 0)
+       if (port->sock != PGINVALID_SOCKET)
            StreamClose(port->sock);
        ConnFree(port);
        return NULL;
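diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 448d975200327f46de091239387057845e7d16bf..8e5e91276773c98a07671681ce2d88470af59c64 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c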
@@ -1585,8 +1585,23 @@ keep_going:                      /* We will come back to here until there is
                    conn->raddr.salen = addr_cur->ai_addrlen;
 
                    /* Open a socket */
-                   conn->sock = socket(addr_cur->ai_family, SOCK_STREAM, 0);
-                   if (conn->sock < 0)
+                   {
+                       /*
+                        * While we use 'pgsocket' as the socket type in the
+                        * backend, we use 'int' for libpq socket values.
+                        * This requires us to map PGINVALID_SOCKET to -1
+                        * on Windows.
+                        * See http://msdn.microsoft.com/en-us/library/windows/desktop/ms740516%28v=vs.85%29.aspx
+                        */
+                       pgsocket sock = socket(addr_cur->ai_family, SOCK_STREAM, 0);
+#ifdef WIN32
+                       if (sock == PGINVALID_SOCKET)
+                           conn->sock = -1;
+                       else
+#endif
+                           conn->sock = sock;
+                   }
+                   if (conn->sock == -1)
                    {
                        /*
                         * ignore socket() failure if we have more addresses
@@ -3123,7 +3138,7 @@ internal_cancel(SockAddr *raddr, int be_pid, int be_key,
                char *errbuf, int errbufsize)
 {
    int         save_errno = SOCK_ERRNO;
-   int         tmpsock = -1;
+   pgsocket    tmpsock = PGINVALID_SOCKET;
    char        sebuf[256];
    int         maxlen;
    struct
@@ -3136,7 +3151,7 @@ internal_cancel(SockAddr *raddr, int be_pid, int be_key,
     * We need to open a temporary connection to the postmaster. Do this with
     * only kernel calls.
     */
-   if ((tmpsock = socket(raddr->addr.ss_family, SOCK_STREAM, 0)) < 0)
+   if ((tmpsock = socket(raddr->addr.ss_family, SOCK_STREAM, 0)) == PGINVALID_SOCKET)
    {
        strlcpy(errbuf, "PQcancel() -- socket() failed: ", errbufsize);
        goto cancel_errReturn;
@@ -3207,7 +3222,7 @@ cancel_errReturn:
                maxlen);
        strcat(errbuf, "\n");
    }
-   if (tmpsock >= 0)
+   if (tmpsock != PGINVALID_SOCKET)
        closesocket(tmpsock);
    SOCK_ERRNO_SET(save_errno);
    return FALSE;
@@ -4620,6 +4635,15 @@ PQerrorMessage(const PGconn *conn)
    return conn->errorMessage.data;
 }
 
+/*
+ * In Windows, socket values are unsigned, and an invalid socket value
+ * (INVALID_SOCKET) is ~0, which equals -1 in comparisons (with no compiler
+ * warning). Ideally we would return an unsigned value for PQsocket() on
+ * Windows, but that would cause the function's return value to differ from
+ * Unix, so we just return -1 for invalid sockets.
+ * http://msdn.microsoft.com/en-us/library/windows/desktop/cc507522%28v=vs.85%29.aspx
+ * http://stackoverflow.com/questions/10817252/why-is-invalid-socket-defined-as-0-in-winsock2-h-c
+ */
 int
 PQsocket(const PGconn *conn)
 {
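diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 0c2966915967bf68ea3f79e7d382f951bdb8de99..3fd217f6435ded7d165b8901ca7b207b949d8be5 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h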
@@ -348,6 +348,7 @@ struct pg_conn
    PGnotify   *notifyTail;     /* newest unreported Notify msg */
 
    /* Connection data */
+   /* See PQconnectPoll() for how we use 'int' and not 'pgsocket'. */
    int         sock;           /* Unix FD for socket, -1 if not connected */
    SockAddr    laddr;          /* Local address */
    SockAddr    raddr;          /* Remote address */