diff --git a/src/portal/portal_mem.c b/src/portal/portal_mem.c index b21bd3d..567b781 100644 --- a/src/portal/portal_mem.c +++ b/src/portal/portal_mem.c @@ -73,6 +73,7 @@ struct task_struct *get_target_task_by_id(portal_region* mem_region) if (target_pid == CURRENT_PID_NUM) { igloo_pr_debug("igloo: Using current task (pid=%d)\n", current->pid); task = current; + get_task_struct(task); } else { igloo_pr_debug("igloo: Looking for task with pid=%d\n", target_pid); /* @@ -86,6 +87,8 @@ struct task_struct *get_target_task_by_id(portal_region* mem_region) */ rcu_read_lock(); task = pid_task(find_pid_ns(target_pid, &init_pid_ns), PIDTYPE_PID); + if (task) + get_task_struct(task); rcu_read_unlock(); } return task; diff --git a/src/portal/portal_osi.c b/src/portal/portal_osi.c index ff0c151..c758348 100644 --- a/src/portal/portal_osi.c +++ b/src/portal/portal_osi.c @@ -129,7 +129,7 @@ static void portal_get_vma_name(struct vm_area_struct *vma, char *buf, size_t bu void handle_op_osi_proc(portal_region *mem_region) { struct task_struct *task; - struct mm_struct *mm; + struct mm_struct *mm = NULL; struct osi_proc *proc; size_t name_len = 0; size_t total_size = sizeof(struct osi_proc); @@ -141,13 +141,13 @@ void handle_op_osi_proc(portal_region *mem_region) if (!task) { igloo_debug_osi("igloo: Handling HYPER_OP_OSI_PROC for NULL task\n"); mem_region->header.op = (HYPER_RESP_READ_FAIL); - return; + goto cleanup; } // Now we can safely use task->pid igloo_debug_osi("igloo: Handling HYPER_OP_OSI_PROC for PID %d\n", task->pid); - mm = task->mm; + mm = get_task_mm(task); // Initialize the OSI proc structure at the beginning of data buffer proc = (struct osi_proc *)data_buf; @@ -201,8 +201,13 @@ void handle_op_osi_proc(portal_region *mem_region) mem_region->header.size = (total_size); mem_region->header.op = (HYPER_RESP_READ_OK); + +cleanup: + if (mm) mmput(mm); + if (task) put_task_struct(task); } + void handle_op_osi_proc_handles(portal_region *mem_region) { struct task_struct *task; @@ -271,7 +276,7 @@ void handle_op_osi_proc_handles(portal_region *mem_region) void handle_op_osi_proc_exe(portal_region *mem_region) { struct task_struct *task; - struct mm_struct *mm; + struct mm_struct *mm = NULL; char *data_buf = PORTAL_DATA(mem_region); char *path; size_t pathlen = 0; @@ -282,19 +287,19 @@ void handle_op_osi_proc_exe(portal_region *mem_region) if (!task) { igloo_debug_osi("igloo: Handling HYPER_OP_OSI_PROC_EXE for NULL task\n"); mem_region->header.op = (HYPER_RESP_READ_FAIL); - return; + goto cleanup; } // Now we can safely use task->pid igloo_debug_osi("igloo: Handling HYPER_OP_OSI_PROC_EXE for PID %d\n", task->pid); - mm = task->mm; + mm = get_task_mm(task); if(mm && mm->exe_file) { char *path_buf = kmalloc(CHUNK_SIZE, GFP_KERNEL); if (!path_buf) { igloo_debug_osi("igloo: Failed to allocate memory for path buffer!\n"); mem_region->header.op = (HYPER_RESP_READ_FAIL); - return; + goto cleanup; } path = d_path(&mm->exe_file->f_path, path_buf, CHUNK_SIZE); if (!IS_ERR(path)) { @@ -315,8 +320,13 @@ void handle_op_osi_proc_exe(portal_region *mem_region) } mem_region->header.size = pathlen; + +cleanup: + if (mm) mmput(mm); + if (task) put_task_struct(task); } + void handle_op_osi_mappings(portal_region *mem_region) { struct task_struct *task; @@ -512,7 +522,7 @@ void handle_op_osi_mappings(portal_region *mem_region) void handle_op_osi_proc_mem(portal_region *mem_region) { struct task_struct *task; - struct mm_struct *mm; + struct mm_struct *mm = NULL; struct osi_proc_mem { __le64 start_brk; __le64 brk; @@ -528,18 +538,18 @@ void handle_op_osi_proc_mem(portal_region *mem_region) proc_mem->start_brk = 0; proc_mem->brk = 0; mem_region->header.op = (HYPER_RESP_READ_FAIL); - return; + goto cleanup; } // Now we can safely use task->pid igloo_debug_osi("igloo: Handling HYPER_OP_OSI_PROC_MEM for PID %d\n", task->pid); - mm = task->mm; + mm = get_task_mm(task); // Check if we have enough buffer space for the structure if (sizeof(struct osi_proc_mem) > CHUNK_SIZE) { mem_region->header.op = (HYPER_RESP_READ_FAIL); - return; + goto cleanup; } proc_mem = (struct osi_proc_mem *)PORTAL_DATA(mem_region); @@ -548,7 +558,7 @@ void handle_op_osi_proc_mem(portal_region *mem_region) proc_mem->start_brk = 0; proc_mem->brk = 0; mem_region->header.op = (HYPER_RESP_READ_FAIL); - return; + goto cleanup; } proc_mem->start_brk = (mm->start_brk); @@ -556,8 +566,13 @@ void handle_op_osi_proc_mem(portal_region *mem_region) mem_region->header.size = (sizeof(struct osi_proc_mem)); mem_region->header.op = (HYPER_RESP_READ_OK); + +cleanup: + if (mm) mmput(mm); + if (task) put_task_struct(task); } + void handle_op_read_procargs(portal_region *mem_region) { struct task_struct *task = get_target_task_by_id(mem_region); @@ -758,7 +773,7 @@ void handle_op_read_fds(portal_region *mem_region) header->total_count = 0; mem_region->header.size = (sizeof(struct osi_result_header)); mem_region->header.op = (HYPER_RESP_READ_FAIL); - return; + goto cleanup; } // ------------------------------------------------------------- @@ -773,7 +788,7 @@ void handle_op_read_fds(portal_region *mem_region) if (!task->fs) { mem_region->header.size = sizeof(struct osi_result_header); mem_region->header.op = HYPER_RESP_READ_OK; - return; + goto cleanup; } // Safely extract the pwd struct @@ -818,7 +833,7 @@ void handle_op_read_fds(portal_region *mem_region) } path_put(&pwd); - return; + goto cleanup; } // ------------------------------------------------------------- @@ -830,7 +845,7 @@ void handle_op_read_fds(portal_region *mem_region) header->total_count = 0; mem_region->header.size = sizeof(struct osi_result_header); mem_region->header.op = HYPER_RESP_READ_FAIL; - return; + goto cleanup; } // ------------------------------------------------------------- @@ -850,7 +865,7 @@ void handle_op_read_fds(portal_region *mem_region) task_unlock(task); mem_region->header.size = (sizeof(struct osi_result_header)); mem_region->header.op = (HYPER_RESP_READ_OK); - return; + goto cleanup; } files = task->files; @@ -939,8 +954,12 @@ void handle_op_read_fds(portal_region *mem_region) igloo_debug_osi("igloo: Returned %d file descriptors (total: %d), buffer used: %zu bytes\n", count, total_count, string_offset); + +cleanup: + if (task) put_task_struct(task); } + void handle_op_read_time(portal_region *mem_region) { mem_region->header.size = ktime_get_ns();