Refactor connection polling

refactor
PgSocks 2 years ago
parent 6eb0d092e6
commit a283cad9ed

@ -13,7 +13,7 @@ struct Client {
struct Connection* sdk_conn; struct Connection* sdk_conn;
struct Connection* dev_conn; struct Connection* dev_conn;
}; };
struct Connection *conns[2]; struct Connection* conns[2];
}; };
}; };

@ -12,6 +12,8 @@
#include <fcntl.h> #include <fcntl.h>
#include <stdio.h> #include <stdio.h>
// The greated file descriptor is needed for polling the sockets.
// It needs to be global for the whole process.
int max_fd = -1; int max_fd = -1;
struct Connection* struct Connection*
@ -54,6 +56,57 @@ connection_new(unsigned int source_port, const char* source_ip, unsigned int des
return conn; return conn;
} }
struct Connection*
connection_poll_ready(struct Client* client) {
// Add all the connections' socket file descriptors to a watch list
fd_set read_fds;
FD_ZERO(&read_fds);
for(int i = 0; i < 2; i++)
if(client->conns[i])
FD_SET(client->conns[i]->sockfd, &read_fds);
//struct timeval timeout = {0, 0};
int result = select(max_fd + 1, &read_fds, NULL, NULL, NULL);
// Check for socket polling errors
if(result < 0) {
perror("message polling failed");
exit(EXIT_FAILURE);
}
// Return a null connection if nothing was received on any of them
if (result == 0)
return NULL;
// Return the first connection with something to read
for(int i = 0; i < 2; i++)
if(FD_ISSET(client->conns[i]->sockfd, &read_fds))
return client->conns[i];
// Return a null connection if somehow none of them have anything to read
return NULL;
}
void
connection_read(struct Connection* conn, union Message* resp) {
int recvb = recvfrom(conn->sockfd, resp, sizeof(union Message), 0, (struct sockaddr*)&conn->remote_addr, &conn->addrlen);
// Check for socket read errors
if(recvb < 0) {
perror("reading socket failed");
exit(EXIT_FAILURE);
}
// Check for message errors
if(message_validate(resp) != MESSAGEERR_NONE) {
perror("invalid message");
exit(EXIT_FAILURE);
}
}
void void
req_send(struct Connection* conn, union Request* req, size_t length) { req_send(struct Connection* conn, union Request* req, size_t length) {
sendto(conn->sockfd, req, length, 0, (struct sockaddr*)&conn->remote_addr, conn->addrlen); sendto(conn->sockfd, req, length, 0, (struct sockaddr*)&conn->remote_addr, conn->addrlen);

@ -6,11 +6,6 @@
#include <stddef.h> #include <stddef.h>
#include <netinet/in.h> #include <netinet/in.h>
// The greated file descriptor is needed for polling the sockets.
// It needs to be global for the whole process.
// TODO: Should probably just move the polling function to connection.c
extern int max_fd;
struct Connection { struct Connection {
int sockfd; int sockfd;
socklen_t addrlen; socklen_t addrlen;
@ -20,6 +15,12 @@ struct Connection {
struct Connection* struct Connection*
connection_new(unsigned int source_port, const char* source_ip, unsigned int dest_port, const char* dest_ip); connection_new(unsigned int source_port, const char* source_ip, unsigned int dest_port, const char* dest_ip);
struct Connection*
connection_poll_ready(struct Client* client);
void
connection_read(struct Connection* connection, union Message* resp);
void void
req_finalize(struct Client* client, uint8_t cmdset, uint8_t cmdid, size_t length, union Request* req); req_finalize(struct Client* client, uint8_t cmdset, uint8_t cmdid, size_t length, union Request* req);

@ -5,11 +5,7 @@
#include "connection.h" #include "connection.h"
#include <stdlib.h> #include <stdlib.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <netinet/in.h>
#include <string.h> #include <string.h>
#include <stdio.h>
Client client_new() { Client client_new() {
struct Client* client = malloc(sizeof(struct Client)); struct Client* client = malloc(sizeof(struct Client));
@ -30,48 +26,8 @@ void poll_message(Client client, union Message* resp) {
memset(resp, 0, sizeof(union Message)); memset(resp, 0, sizeof(union Message));
// Poll for messages struct Connection* conn = connection_poll_ready(client);
static struct timeval timeout = {0, 0}; if(conn)
fd_set read_fds; connection_read(conn, resp);
FD_ZERO(&read_fds);
FD_SET(client->sdk_conn->sockfd, &read_fds);
FD_SET(client->dev_conn->sockfd, &read_fds);
int result = select(max_fd + 1, &read_fds, NULL, NULL, NULL);
// Check for socket polling errors
if(result < 0) {
perror("message polling failed");
exit(EXIT_FAILURE);
}
// Skip if nothing was received yet
// TODO: Make a static "empty" message or something
if (result == 0)
return;
// Read a message from the sockets
for(int i = 0; i < 2; i++)
{
if(!FD_ISSET(client->conns[i]->sockfd, &read_fds))
continue;
int recvb = recvfrom(client->sdk_conn->sockfd, resp, sizeof(union Message), 0, (struct sockaddr*)&client->sdk_conn->remote_addr, &client->sdk_conn->addrlen);
// Check for socket read errors
if(recvb < 0) {
perror("reading socket failed");
exit(EXIT_FAILURE);
}
// Check for message errors
if(message_validate(resp) != MESSAGEERR_NONE) {
perror("invalid message");
exit(EXIT_FAILURE);
}
return;
}
} }

Loading…
Cancel
Save