#include "vmotionos_server.h"
#include "vmotionos.h"
#include "vmotionos_client.h"
#include <linux/mm.h>
#include <linux/mman.h>
#include <linux/pgtable.h>
#include <linux/sched/mm.h>
#include <linux/sched/task.h>
#include <linux/threads.h>
#include <net/sock.h>
#include <asm/pgtable_types.h>

/*
 * Server state shared by vmotionos_server_start(), vmotionos_server_stop()
 * and udp_server_thread(). Both pointers start out NULL (static storage).
 */
static struct socket *udp_socket;       /* UDP socket the server thread binds and receives on */
static struct task_struct *udp_thread;  /* kthread running udp_server_thread() */

static int udp_server_thread(void *data)
{
    struct sockaddr_in client_addr, server_addr;
    struct msghdr msg;
    struct kvec iov;
    char *buffer;
    int ret;

    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(SERVER_PORT);
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);

    /* Allocate buffer */
    buffer = kmalloc(BUFFER_SIZE, GFP_KERNEL);
    if (!buffer) {
        vmotionos_err("Server: Failed to allocate buffer\n");
        return -ENOMEM;
    }

    /* Bind socket */
    ret = kernel_bind(udp_socket, (struct sockaddr *)&server_addr, sizeof(server_addr));
    if (ret < 0) {
        vmotionos_err("Server: Failed to bind socket, error %d\n", ret);
        kfree(buffer);
        return ret;
    }
    
    vmotionos_info("Server: Listening on port %d\n", SERVER_PORT);

    while (!kthread_should_stop()) {
        /* Reset structures */
        memset(buffer, 0, BUFFER_SIZE);
        memset(&client_addr, 0, sizeof(client_addr));
        memset(&msg, 0, sizeof(msg));

        /* Setup receive parameters */
        iov.iov_base = buffer;
        iov.iov_len = BUFFER_SIZE - 1;

        msg.msg_name = &client_addr;
        msg.msg_namelen = sizeof(client_addr);

        /* Receive message */
        ret = kernel_recvmsg(udp_socket, &msg, &iov, 1, BUFFER_SIZE - 1, 0);
        if (ret > 0) {
            /* Check if this is a typed message */
            if (ret >= sizeof(struct vmotionos_message)) {
                struct vmotionos_message *typed_msg = (struct vmotionos_message *)buffer;
                
                vmotionos_debug("Server: Received message (type: %d, data_size: %zu, total: %d bytes)\n",
                               typed_msg->msg_type, typed_msg->data_size, ret);
                
                switch (typed_msg->msg_type) {
                    case MSG_TYPE_CREATE_THREAD:
                        /* Component-by-component unpacking */
                        {
                            size_t basic_size = sizeof(int) + MAX_COMM_LEN;
                            size_t regs_size = sizeof(struct vmotionos_regs);
                            size_t mm_meta_size = sizeof(struct vmotionos_mm) - (MAX_VMA_COUNT * sizeof(struct vmotionos_vma));
                            size_t files_meta_size = sizeof(struct vmotionos_files) - (MAX_FD_COUNT * sizeof(struct vmotionos_fd));
                            size_t min_size = basic_size + regs_size + mm_meta_size + files_meta_size;
                            struct vmotionos_thread *thread;
                            char *read_ptr;
                                
                            if (typed_msg->data_size >= min_size) {
                                thread = kzalloc(sizeof(struct vmotionos_thread), GFP_KERNEL);
                                if (!thread) {
                                    vmotionos_err("Server: Failed to allocate thread structure\n");
                                    break;
                                }
                                
                                read_ptr = typed_msg->data;
                                vmotionos_debug("Server: Starting component unpacking (min_size: %zu)\n", min_size);
                                
                                /* 1. Read basic info: pid + comm */
                                memcpy(&thread->pid, read_ptr, sizeof(int));
                                read_ptr += sizeof(int);
                                memcpy(thread->comm, read_ptr, MAX_COMM_LEN);
                                read_ptr += MAX_COMM_LEN;
                                vmotionos_debug("Server: Unpacked basic info - PID: %d (0x%x), COMM: %s\n", 
                                               thread->pid, thread->pid, thread->comm);
                                
                                /* 2. Read CPU registers */
                                memcpy(&thread->regs, read_ptr, regs_size);
                                read_ptr += regs_size;
                                vmotionos_debug("Server: Unpacked registers (%zu bytes)\n", regs_size);
                                
                                /* 3. Read MM metadata */
                                memcpy(&thread->mm, read_ptr, mm_meta_size);
                                read_ptr += mm_meta_size;
                                vmotionos_debug("Server: Unpacked MM metadata (%zu bytes) - VMAs: %d\n", 
                                               mm_meta_size, thread->mm.vma_count);
                                
                                /* 4. Read files metadata */
                                memcpy(&thread->files, read_ptr, files_meta_size);
                                read_ptr += files_meta_size;
                                vmotionos_debug("Server: Unpacked files metadata (%zu bytes) - FDs: %d\n", 
                                               files_meta_size, thread->files.fd_count);
                                
                                vmotionos_debug("Server: Unpacking CREATE_THREAD for '%s' with %d VMAs, %d FDs\n",
                                               thread->comm, thread->mm.vma_count, thread->files.fd_count);
                                
                                /* 5. Read actual VMAs */
                                if (thread->mm.vma_count > 0 && thread->mm.vma_count <= MAX_VMA_COUNT) {
                                    size_t vma_bytes = thread->mm.vma_count * sizeof(struct vmotionos_vma);
                                    if (read_ptr + vma_bytes <= typed_msg->data + typed_msg->data_size) {
                                        memcpy(thread->mm.vmas, read_ptr, vma_bytes);
                                        read_ptr += vma_bytes;
                                        vmotionos_debug("Server: Unpacked %d VMAs (%zu bytes)\n", 
                                                      thread->mm.vma_count, vma_bytes);
                                    } else {
                                        vmotionos_warn("Server: VMA data truncated\n");
                                    }
                                }
                                
                                /* 6. Read actual FDs */  
                                if (thread->files.fd_count > 0 && thread->files.fd_count <= MAX_FD_COUNT) {
                                    size_t fd_bytes = thread->files.fd_count * sizeof(struct vmotionos_fd);
                                    if (read_ptr + fd_bytes <= typed_msg->data + typed_msg->data_size) {
                                        memcpy(thread->files.fds, read_ptr, fd_bytes);
                                        vmotionos_debug("Server: Unpacked %d FDs (%zu bytes) - SUCCESS!\n", 
                                                      thread->files.fd_count, fd_bytes);
                                    } else {
                                        vmotionos_warn("Server: FD data truncated\n");
                                    }
                                }
                                
                                vmotionos_process_received_thread(thread);
                                kfree(thread);
                            } else {
                                vmotionos_warn("Server: CREATE_THREAD message too small: %zu bytes (min: %zu)\n",
                                              typed_msg->data_size, min_size);
                            }
                        }
                        break;
                    
                    case MSG_TYPE_DESTROY_THREAD:
                        vmotionos_info("Server: Received DESTROY_THREAD message\n");
                        break;
                        
                    case MSG_TYPE_MIGRATE_THREAD:
                        vmotionos_info("Server: Received MIGRATE_THREAD message\n");
                        break;
                        
                    case MSG_TYPE_PAGE_REQUEST:
                        {
                            struct vmotionos_page_request *page_req;
                            
                            if (typed_msg->data_size >= sizeof(struct vmotionos_page_request)) {
                                page_req = (struct vmotionos_page_request *)typed_msg->data;
                                vmotionos_debug("Server: Received PAGE_REQUEST for PID %d (0x%x), addr 0x%lx\n",
                                               page_req->source_pid, page_req->source_pid, page_req->page_addr);
                                vmotionos_debug("Server: Received from IP: %s, Port: %d\n",
                                               page_req->source_ip, page_req->source_port);
                                vmotionos_handle_page_request(page_req);
                            } else {
                                vmotionos_warn("Server: PAGE_REQUEST message too small: %zu bytes\n",
                                              typed_msg->data_size);
                            }
                        }
                        break;
                        
                    case MSG_TYPE_PAGE_RESPONSE:
                        {
                            struct vmotionos_page_response *page_resp;
                            
                            if (typed_msg->data_size >= sizeof(struct vmotionos_page_response)) {
                                page_resp = (struct vmotionos_page_response *)typed_msg->data;
                                vmotionos_debug("Server: Received PAGE_RESPONSE for PID %d, addr 0x%lx\n",
                                               page_resp->source_pid, page_resp->page_addr);
                                vmotionos_handle_page_response(page_resp);
                            } else {
                                vmotionos_warn("Server: PAGE_RESPONSE message too small: %zu bytes\n",
                                              typed_msg->data_size);
                            }
                        }
                        break;
                    
                    default:
                        vmotionos_warn("Server: Unknown message type: %d\n", typed_msg->msg_type);
                        break;
                }
            } else {
                /* Handle as simple message */
                buffer[ret] = '\0';
                vmotionos_debug("Server: Received simple message - Size: %d bytes\n", ret);
            }
        } else if (ret < 0 && ret != -EINTR) {
            vmotionos_err("Server: Receive error %d\n", ret);
        }
    }

    kfree(buffer);
    return 0;
}

int vmotionos_server_start(void)
{
    int ret;

    vmotionos_info("Server: Starting UDP server...\n");

    /* Create UDP socket */
    ret = sock_create(AF_INET, SOCK_DGRAM, IPPROTO_UDP, &udp_socket);
    if (ret < 0) {
        vmotionos_err("Server: Failed to create socket, error %d\n", ret);
        return ret;
    }

    /* Start server thread */
    udp_thread = kthread_run(udp_server_thread, NULL, "vmotionos_server");
    if (IS_ERR(udp_thread)) {
        vmotionos_err("Server: Failed to create thread\n");
        sock_release(udp_socket);
        return PTR_ERR(udp_thread);
    }

    vmotionos_info("Server: Started successfully on port %d\n", SERVER_PORT);
    return 0;
}

void vmotionos_server_stop(void)
{
    vmotionos_info("Server: Stopping server...\n");

    if (udp_thread) {
        kthread_stop(udp_thread);
        vmotionos_info("Server: Thread stopped\n");
    }

    if (udp_socket) {
        sock_release(udp_socket);
        vmotionos_info("Server: Socket released\n");
    }
}

/*
 * Walk the page tables of @mm and return the physical address of the page
 * frame that maps @addr (page base, without the in-page offset).
 *
 * Returns 0 when @mm is NULL, when any page-table level is empty or "bad"
 * (which also covers huge mappings at PUD/PMD level), or when the final
 * PTE is not present. A non-present PTE (for example a swap entry) is not
 * pte_none(), but its bits do not encode a page frame, so it must not be
 * passed to pte_pfn().
 *
 * NOTE(review): callers are expected to hold the mmap lock of @mm so the
 * tables cannot be torn down during the walk - confirm at each call site.
 */
static resource_size_t get_phys_addr_from_pt(struct mm_struct *mm, unsigned long addr)
{
    pgd_t *pgd;
    p4d_t *p4d;
    pud_t *pud;
    pmd_t *pmd;
    pte_t *ptep;
    pte_t entry;

    if (!mm)
        return 0;

    pgd = pgd_offset(mm, addr);
    if (pgd_none(*pgd) || pgd_bad(*pgd))
        return 0;

    p4d = p4d_offset(pgd, addr);
    if (p4d_none(*p4d) || p4d_bad(*p4d))
        return 0;

    pud = pud_offset(p4d, addr);
    if (pud_none(*pud) || pud_bad(*pud))
        return 0;

    pmd = pmd_offset(pud, addr);
    if (pmd_none(*pmd) || pmd_bad(*pmd))
        return 0;

    ptep = pte_offset_map(pmd, addr);
    if (!ptep)
        return 0;

    /* Snapshot the entry, then drop the mapping before inspecting it. */
    entry = *ptep;
    pte_unmap(ptep);

    if (!pte_present(entry))
        return 0;

    return (resource_size_t)pte_pfn(entry) << PAGE_SHIFT;
}

/* Handle page request from destination node */
void vmotionos_handle_page_request(struct vmotionos_page_request *request)
{
    struct task_struct *source_task;
    struct vm_area_struct *vma;
    void *page_data;
    struct vmotionos_page_response *response;
    int ret;
    
    vmotionos_debug("Handling page request for PID %d, addr 0x%lx\n",
                   request->source_pid, request->page_addr);
    
    /* Validate PID range */
    if (request->source_pid <= 0 || request->source_pid > 65535) {
        vmotionos_err("Failed INVALID PID: %d (out of valid range 1-65535)\n", request->source_pid);
        vmotionos_err("This suggests data corruption during transmission\n");
        return;
    }
    
    /* Find the source task */
    rcu_read_lock();
    source_task = pid_task(find_vpid(request->source_pid), PIDTYPE_PID);
    if (!source_task) {
        rcu_read_unlock();
        vmotionos_err("Source task PID %d not found\n", request->source_pid);
        return;
    }
    
    /* Check if the address is in the specified VMA range */
    if (request->page_addr < request->vma_start || request->page_addr >= request->vma_end) {
        rcu_read_unlock();
        vmotionos_err("Page address 0x%lx not in VMA range [0x%lx-0x%lx]\n",
                      request->page_addr, request->vma_start, request->vma_end);
        return;
    }
    mmap_read_lock(source_task->mm);
    /* Find the VMA containing this address */
    vma = find_vma(source_task->mm, request->page_addr);
    if (!vma || request->page_addr < vma->vm_start || request->page_addr >= vma->vm_end) {
        mmap_read_unlock(source_task->mm);
        rcu_read_unlock();
        vmotionos_err("No VMA found for address 0x%lx\n", request->page_addr);
        return;
    }
    vmotionos_debug("Found VMA 0x%lx-0x%lx for address 0x%lx, flags=0x%lx\n",
                   vma->vm_start, vma->vm_end, request->page_addr, vma->vm_flags);
    
    /* Try to read the page data from the source process */
    page_data = kmalloc(PAGE_SIZE, GFP_KERNEL);
    if (!page_data) {
        mmap_read_unlock(source_task->mm);
        rcu_read_unlock();
        vmotionos_err("Failed to allocate page data buffer\n");
        return;
    }
    
    /* Read the page data from the source process */
    ret = access_process_vm(source_task, request->page_addr, page_data, PAGE_SIZE, 0);
    
    if (ret != PAGE_SIZE) {
        vmotionos_debug("Normal access_process_vm failed: %d bytes read (expected %lu)\n", ret, PAGE_SIZE);
        vmotionos_debug("Address: 0x%lx, VMA: 0x%lx-0x%lx, flags=0x%lx\n", 
                      request->page_addr, request->vma_start, request->vma_end, vma->vm_flags);
        
        /* Check if this is a special region that needs page table walking */
        if (ret == 0 && (vma->vm_flags & (VM_PFNMAP | VM_IO))) {
            vmotionos_debug("Attempting page table walk for PFNMAP/IO region\n");
            
            /* Try page table walk for special regions (e.g., VDSO/PFNMAP) */
            resource_size_t phys_addr = get_phys_addr_from_pt(source_task->mm, request->page_addr);
            if (phys_addr) {
                void *src = __va(phys_addr); /* Kernel direct mapping */
                if (src) {
                    memcpy(page_data, src, PAGE_SIZE);
                    ret = PAGE_SIZE; /* Mark as successful */
                    vmotionos_debug("Page table walk successful for special region (phys: 0x%llx)\n", phys_addr);
                } else {
                    vmotionos_err("Failed to map physical address 0x%llx\n", phys_addr);
                }
            } else {
                vmotionos_err("Failed to get physical address via page table walk\n");
            }
        }
        
        /* If still failed, handle as before */
        if (ret != PAGE_SIZE) {
            if (ret == 0) {
                vmotionos_err("Page appears to be unmapped or inaccessible (likely demand-paged)\n");
                vmotionos_info("Providing zero page for demand-paged memory\n");
                
                /* For demand-paged memory, provide a zero page */
                memset(page_data, 0, PAGE_SIZE);
                ret = PAGE_SIZE; /* Pretend we read successfully */
            } else if (ret < 0) {
                vmotionos_err("Error reading page: %d\n", ret);
                mmap_read_unlock(source_task->mm);
                rcu_read_unlock();
                kfree(page_data);
                return;
            } else {
                vmotionos_err("Partial read: %d bytes\n", ret);
                mmap_read_unlock(source_task->mm);
                rcu_read_unlock();
                kfree(page_data);
                return;
            }
        }
    }
    
    mmap_read_unlock(source_task->mm);
    rcu_read_unlock();
    
    vmotionos_debug("Successfully obtained page data from source process (%s method)\n", 
                   ret == PAGE_SIZE ? "direct" : "fallback");
    response = kzalloc(sizeof(struct vmotionos_page_response), GFP_KERNEL);
    if (!response) {
        vmotionos_err("Failed to allocate page response structure\n");
        kfree(page_data);
        return;
    }
    
    /* Prepare response */
    response->source_pid = request->source_pid;
    response->page_addr = request->page_addr;
    response->success = 1;
    memcpy(response->page_data, page_data, PAGE_SIZE);
    vmotionos_debug("Server: Sending response to IP: %s, Port: %d\n", request->dest_ip, request->dest_port);
    kfree(page_data);
    
    /* Send response back to destination node */
    ret = vmotionos_send_page_response(response, request->dest_ip, request->dest_port);
    if (ret < 0) {
        vmotionos_err("Failed to send page response: %d\n", ret);
    } else {
        vmotionos_debug("Page response sent successfully\n");
    }
}

/*
 * Handle page response from source node.
 *
 * Reports the outcome for response->page_addr through
 * vmotionos_set_page_response(): on success with a freshly allocated copy
 * of the page contents and status 1, otherwise with a NULL buffer and
 * status 0 (also used when the copy cannot be allocated).
 *
 * NOTE(review): the copy is not freed here; presumably
 * vmotionos_set_page_response() takes ownership of it - confirm.
 */
void vmotionos_handle_page_response(struct vmotionos_page_response *response)
{
    void *copy;

    vmotionos_debug("Received page response for PID %d, addr 0x%lx, success: %d\n",
                   response->source_pid, response->page_addr, response->success);

    /* Source node reported a failure: pass it on to the waiter. */
    if (!response->success) {
        vmotionos_err("Page request failed on source node\n");
        vmotionos_set_page_response(response->page_addr, NULL, 0);
        return;
    }

    /* The response buffer belongs to the caller, so duplicate the page. */
    copy = kmalloc(PAGE_SIZE, GFP_KERNEL);
    if (copy == NULL) {
        vmotionos_err("Failed to allocate memory for page data\n");
        /* The waiter must still be notified, as a failure. */
        vmotionos_set_page_response(response->page_addr, NULL, 0);
        return;
    }

    memcpy(copy, response->page_data, PAGE_SIZE);
    vmotionos_debug("Page data received successfully (%lu bytes)\n", PAGE_SIZE);

    /* Notify the waiting fault handler */
    vmotionos_set_page_response(response->page_addr, copy, 1);
}

/* Send page request to source node */
int vmotionos_send_page_request(struct vmotionos_page_request *request, char *dest_ip, int dest_port)
{
    struct vmotionos_message *typed_msg;
    char *message_buffer;
    size_t total_msg_size;
    struct vmotionos_message_params msg;
    int ret;
    
    if (!request || !dest_ip) return -EINVAL;
    
    total_msg_size = sizeof(struct vmotionos_message) + sizeof(struct vmotionos_page_request);
    
    vmotionos_debug("Sending page request to %s:%d for PID %d (0x%x), addr 0x%lx\n",
                   dest_ip, dest_port, request->source_pid, request->source_pid, request->page_addr);
    
    /* Create typed message */
    message_buffer = kzalloc(total_msg_size, GFP_KERNEL);
    if (!message_buffer) {
        vmotionos_err("Failed to allocate message buffer (%zu bytes)\n", total_msg_size);
        return -ENOMEM;
    }
    
    typed_msg = (struct vmotionos_message *)message_buffer;
    typed_msg->msg_type = MSG_TYPE_PAGE_REQUEST;
    typed_msg->data_size = sizeof(struct vmotionos_page_request);
    
    /* Copy page request data */
    memcpy(typed_msg->data, request, sizeof(struct vmotionos_page_request));
    
    /* Send UDP message */
    msg.dest_ip = dest_ip;
    msg.dest_port = dest_port;
    msg.message = message_buffer;
    msg.msg_len = total_msg_size;
    
    ret = vmotionos_send_message(&msg);
    if (ret < 0) {
        vmotionos_err("Failed to send PAGE_REQUEST message: %d\n", ret);
    } else {
        vmotionos_debug("Successfully sent PAGE_REQUEST message\n");
    }
    
    kfree(message_buffer);
    return ret;
}

/* Send page response to destination node */
int vmotionos_send_page_response(struct vmotionos_page_response *response, char *dest_ip, int dest_port)
{
    struct vmotionos_message *typed_msg;
    char *message_buffer;
    size_t total_msg_size;
    struct vmotionos_message_params msg;
    int ret;
    
    if (!response || !dest_ip) return -EINVAL;
    
    total_msg_size = sizeof(struct vmotionos_message) + sizeof(struct vmotionos_page_response);
    
    vmotionos_debug("Sending page response to %s:%d for PID %d, addr 0x%lx\n",
                   dest_ip, dest_port, response->source_pid, response->page_addr);
    
    /* Create typed message */
    message_buffer = kzalloc(total_msg_size, GFP_KERNEL);
    if (!message_buffer) {
        vmotionos_err("Failed to allocate message buffer (%zu bytes)\n", total_msg_size);
        return -ENOMEM;
    }
    
    typed_msg = (struct vmotionos_message *)message_buffer;
    typed_msg->msg_type = MSG_TYPE_PAGE_RESPONSE;
    typed_msg->data_size = sizeof(struct vmotionos_page_response);
    
    /* Copy page response data */
    memcpy(typed_msg->data, response, sizeof(struct vmotionos_page_response));
    
    /* Send UDP message */
    msg.dest_ip = dest_ip;
    msg.dest_port = dest_port;
    msg.message = message_buffer;
    msg.msg_len = total_msg_size;
    
    ret = vmotionos_send_message(&msg);
    if (ret < 0) {
        vmotionos_err("Failed to send PAGE_RESPONSE message: %d\n", ret);
    } else {
        vmotionos_debug("Successfully sent PAGE_RESPONSE message\n");
    }
    
    kfree(message_buffer);
    return ret;
}
