 /*
  * vdev_netlink.c
  * Copyright (C) 2025  Aitor C.Z. <aitor_czr@gnuinos.org>
  * 
  * This program is free software: you can redistribute it and/or modify it
  * under the terms of the GNU General Public License as published by the
  * Free Software Foundation, either version 3 of the License, or
  * (at your option) any later version.
  * 
  * This program is distributed in the hope that it will be useful, but
  * WITHOUT ANY WARRANTY; without even the implied warranty of
  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  * See the GNU General Public License for more details.
  * 
  * You should have received a copy of the GNU General Public License along
  * with this program.  If not, see <http://www.gnu.org/licenses/>.
  * 
  * See the COPYING file.
  */

#include "libvdev/sglib.h"
#include "libvdev/util.h"
#include "libvdev/misc.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <stdbool.h>
#include <mntent.h>
#include <sys/stat.h>

#include <poll.h>
#include <linux/netlink.h>
#include <sys/socket.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <signal.h>
#include <unistd.h>

// Size, in bytes, of the buffer that main() uses to receive one netlink
// datagram.  A datagram whose length reaches this value is rejected, so the
// largest message that is processed is one byte shorter than the buffer.
#define VDEV_LINUX_NETLINK_BUF_MAX 4097

// Single-token alias for "char *", used as the type argument of the SGLIB
// vector macros below.
typedef char *cstr;

// Instantiate the SGLIB vector prototypes and functions for cstr elements.
// The macros are presumably provided by "libvdev/sglib.h" -- TODO confirm.
// NOTE(review): nothing in the visible part of this file uses the generated
// vector; verify whether this instantiation is still needed.
SGLIB_DEFINE_VECTOR_PROTOTYPES (cstr);
SGLIB_DEFINE_VECTOR_FUNCTIONS (cstr);

// State of the netlink connection to the Linux kernel used to receive
// hotplug uevents.
struct vdev_linux_context {
    // netlink address the socket is bound to (family, port id and
    // multicast groups are filled in by main())
    struct sockaddr_nl nl_addr;
   
    // poll descriptor wrapping the netlink socket; pfd.fd holds the socket
    // and a negative value means "no socket open"
    struct pollfd pfd;
};

/*
 * Release the resources held by a netlink context.
 *
 * Closes the socket stored in ctx->pfd.fd when it is a valid (non-negative)
 * descriptor and marks it as closed by storing -1.  The context structure
 * itself is not deallocated.  Passing NULL, or a context whose descriptor is
 * already negative, is a no-op.
 */
static void vdev_linux_context_free(struct vdev_linux_context *ctx)
{
    // nothing to release without a context or without an open socket
    if (ctx == NULL || ctx->pfd.fd < 0) {
        return;
    }

    close (ctx->pfd.fd);
    ctx->pfd.fd = -1;
}

int main(int argc, char **argv)
{
    int rc;
    char *buf = NULL;
    ssize_t nr = 0;
    struct vdev_linux_context ctx;
    const char *progname = "vdev_netlink";
   
    /* netlink */
    ssize_t len = 0;
    char uevent_buf[VDEV_LINUX_NETLINK_BUF_MAX];
    char cbuf[CMSG_SPACE(sizeof(struct ucred))];
    struct cmsghdr *chdr = NULL;
    struct ucred *cred = NULL;
    struct msghdr hdr;
    struct iovec iov;
    struct sockaddr_nl cnls;
    size_t slen = 128 * 1024 * 1024;
    int so_passcred_enable = 1;
   
    send_to_background();
   
    /* Netlink phase: wait for new uevent files */
   
    memset(&hdr, 0, sizeof(struct msghdr));

    while (1) {
        chdr = NULL;
        cred = NULL;

        ctx.nl_addr.nl_family = AF_NETLINK;
        ctx.nl_addr.nl_pid = getpid();
        ctx.nl_addr.nl_groups = NETLINK_KOBJECT_UEVENT;
      
        ctx.pfd.fd = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_KOBJECT_UEVENT);
        if (ctx.pfd.fd < 0) {
            rc = -errno;
            fprintf(stderr, "socket(PF_NETLINK) rc = %d\n", rc);
            vdev_linux_context_free (&ctx);
            return rc;
        }
      
        ctx.pfd.events = POLLIN;
      
        // big receive buffer, if running as root 
        if (geteuid() == 0) {
            rc = setsockopt(ctx.pfd.fd, SOL_SOCKET, SO_RCVBUFFORCE, &slen, sizeof(slen));
            if (rc < 0) {
                rc = -errno;
                fprintf(stderr, "setsockopt(SO_RCVBUFFORCE) rc = %d\n", rc);
                vdev_linux_context_free (&ctx);
                return rc;
            }
        }
      
        // check credentials of message--only root should be able talk to us
        rc = setsockopt(ctx.pfd.fd, SOL_SOCKET, SO_PASSCRED, &so_passcred_enable, sizeof(so_passcred_enable));
        if (rc < 0) {
            rc = -errno;
            fprintf(stderr, "setsockopt(SO_PASSCRED) rc = %d\n", rc);
            vdev_linux_context_free (&ctx);
            return rc;
        }
     
        // bind to the address
        rc = bind(ctx.pfd.fd, (struct sockaddr*)&ctx.nl_addr, sizeof(struct sockaddr_nl));
        if (rc != 0) {
            rc = -errno;
            fprintf(stderr, "bind(%d) rc = %d\n", ctx.pfd.fd, rc);
            vdev_linux_context_free (&ctx);
            return rc;
        }
  
        // next event (wait forever)
        // NOTE: this is a cancellation point!
        rc = poll(&ctx.pfd, 1, -1);
   
        // get the event 
        iov.iov_base = uevent_buf;
        iov.iov_len = VDEV_LINUX_NETLINK_BUF_MAX;
        
        hdr.msg_iov = &iov;
        hdr.msg_iovlen = 1;
   
        // get control-plane messages
        hdr.msg_control = cbuf;
        hdr.msg_controllen = sizeof(cbuf);
   
        hdr.msg_name = &cnls;
        hdr.msg_namelen = sizeof(cnls);

        // get the event 
        len = recvmsg(ctx.pfd.fd, &hdr, 0);
        if (len < 0) {
            rc = -errno;
            fprintf(stderr, "FATAL: recvmsg(%d) rc = %d\n", ctx.pfd.fd, rc);
            vdev_linux_context_free (&ctx);
            return rc;
        }
   
        // big enough?
        if (len < 32 || len >= VDEV_LINUX_NETLINK_BUF_MAX) {
            fprintf(stderr, "Netlink message is %zd bytes; ignoring...\n", len);
            vdev_linux_context_free (&ctx);
            return -EAGAIN;
        }
   
        // control message, for credentials
        chdr = CMSG_FIRSTHDR(&hdr);
        if (chdr == NULL || chdr->cmsg_type != SCM_CREDENTIALS) {
            fprintf(stderr, "%s", "Netlink message has no credentials\n");
            vdev_linux_context_free (&ctx);
            return -EAGAIN;
        }
   
        // get the credentials
        cred = (struct ucred *)CMSG_DATA(chdr);
   
        // if not root, ignore 
        if (cred->uid != 0) {
            fprintf(stderr, "Ignoring message from non-root ID %d\n", cred->uid);
            vdev_linux_context_free (&ctx);
            return -EAGAIN;
        }
   
        // if udev, ignore... they are user space /dev events, and we are 
        // interested in driver core uevents that occur in the kernel space only 
        if (memcmp(uevent_buf, "libudev", 8) == 0) {
            // message from udev; ignore 
            fprintf(stderr, "%s", "Ignoring libudev message\n");
            vdev_linux_context_free (&ctx);
            return -EAGAIN;
        }
   
        // kernel messages don't come from userspace 
        if (cnls.nl_pid > 0) {
            // from userspace???
            fprintf(stderr, "Ignoring message from PID %d\n", (int)cnls.nl_pid);
            vdev_linux_context_free (&ctx);
            return -EAGAIN;
        }
   
        if (strstr(uevent_buf, "@/") == NULL) {
            // invalid header 
            fprintf(stderr, "%s", "invalid message header: missing '@' directive");
            vdev_linux_context_free (&ctx);
        return -EBADMSG;
        } 
   
        for (unsigned int i = 0; i < len;) {
            if (!strncmp(uevent_buf+i, "MODALIAS=", 9)) {
                FILE *pfin = NULL;
                pid_t pid;
                int wstatus;
                char cmd[1024]={0};
                char *tmp = rindex(uevent_buf + i, '=') + 1;
     
                strncpy_t (cmd, sizeof(cmd), "/sbin/modprobe -b -q -- \"", strlen("/sbin/modprobe -b -q -- \""));
                strcat(cmd, tmp);
                strcat(cmd, "\"");
                
                pfin = epopen(cmd, &pid);
                if (pfin) { 
                    fclose(pfin);
                    waitpid(pid, &wstatus, 0);
                }
                /*
                if (wstatus != 0)
                    fprintf(stderr, "Modprobe failed: %s\n", tmp);
                */
            }
            i += strlen(uevent_buf + i) + 1;
        }
        vdev_linux_context_free (&ctx);
    } // while
	
    return 0;
}
