From db35af8fc40ac20986a4e3ce926d1da3b1f427ce Mon Sep 17 00:00:00 2001
From: Amelia Coutard <eliottulio.coutard@gmail.com>
Date: Mon, 27 Feb 2023 13:32:10 +0100
Subject: [PATCH] Added missing stack to ftl_to_userspace

---
 kernel/src/kernel.cpp    | 45 +++++++++++++++++++++++++++++++++-------
 kernel/src/paging.hpp    |  2 +-
 kernel/src/ring3.S       |  1 +
 kernel/src/ring3.hpp     |  3 ++-
 test_module/module.mk    |  2 +-
 test_module/src/test.cpp |  8 +++++--
 6 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/kernel/src/kernel.cpp b/kernel/src/kernel.cpp
index 1485a17..a2cbd58 100644
--- a/kernel/src/kernel.cpp
+++ b/kernel/src/kernel.cpp
@@ -231,6 +231,43 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
 			program_header.p_vaddr[i] = std::byte(0);
 		}
 	}
+
+	// Allocate and map stack.
+	constexpr std::size_t stack_size = 16 * 0x1000 /* 64KiB */;
+	std::byte * const stack = (std::byte*)0x0000'8000'0000'0000 - stack_size;
+	for (std::size_t i = 0; i < stack_size / 0x1000 /* 64KiB */; i++) {
+		const auto indices = os::paging::calc_page_table_indices(stack + i * 0x1000);
+		os::assert(indices.pml4e < 256, "Userspace program must be in the lower-half of virtual memory.");
+		if (PML4T.contents[indices.pml4e].P == 0) {
+			PML4T.contents[indices.pml4e] = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0};
+			const auto PDPT_alloc = os::paging::page_allocator.allocate(1);
+			set_base_address(PML4T.contents[indices.pml4e], os::phys_ptr<os::paging::PDPT>(PDPT_alloc.ptr.get_phys_addr()));
+		}
+		os::paging::PDPT& PDPT = *get_base_address(PML4T.contents[indices.pml4e]);
+		if (PDPT.contents[indices.pdpe].non_page.P == 0) {
+			PDPT.contents[indices.pdpe] = {.non_page = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
+			const auto PDT_alloc = os::paging::page_allocator.allocate(1);
+			set_base_address(PDPT.contents[indices.pdpe], os::phys_ptr<os::paging::PDT>(PDT_alloc.ptr.get_phys_addr()));
+		}
+		os::paging::PDT& PDT = *get_base_address(PDPT.contents[indices.pdpe]);
+		if (PDT.contents[indices.pde].non_page.P == 0) {
+			PDT.contents[indices.pde] = {.non_page = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
+			const auto PT_alloc = os::paging::page_allocator.allocate(1);
+			set_base_address(PDT.contents[indices.pde], os::phys_ptr<os::paging::PT>(PT_alloc.ptr.get_phys_addr()));
+		}
+		os::paging::PT& PT = *get_base_address(PDT.contents[indices.pde]);
+		if (PT.contents[indices.pe].P == 0) {
+			PT.contents[indices.pe] = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .PAT = 0, .G = 0, .base_address = 0, .NX = 0};
+			const auto page_alloc = os::paging::page_allocator.allocate(1);
+			set_page_base_address(PT.contents[indices.pe], os::phys_ptr<os::paging::page>(page_alloc.ptr.get_phys_addr()));
+		}
+	}
+
+	// Initialise stack to 0.
+	for (std::size_t i = 0; i < stack_size; i++) {
+		stack[i] = std::byte(0);
+	}
+
 	asm volatile("invlpg (%0)" ::"r" (0x1000) : "memory");
 
 	os::paging::on_all_pages(PML4T, [](os::paging::page* virt, os::phys_ptr<os::paging::page> phys, std::size_t size) {
@@ -240,12 +277,6 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
 		os::print("{}->{} ({})\n", std::uint64_t(virt), std::uint64_t(phys.get_phys_addr()), size);
 	});
 
-	// Allow kernel in ring 3. Otherwise, we just immediately page fault.
-	// Will make it a module soon-ish.
-	get_base_address(PML4T.contents[511])->contents[510].page.U_S = true;
-	get_base_address(PML4T.contents[511])->contents[511].page.U_S = true;
-	PML4T.contents[511].U_S = true;
-
 	os::print("Loading ring 3 interrupts stack.\n");
 	os::set_ring0_stack(TSS, std::uint64_t(&interrupt_stack_top));
 	os::print("Loading TSS.\n");
@@ -253,5 +284,5 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
 	os::print("Enabling syscalls.\n");
 	os::enable_syscalls();
 	os::print("Moving to ring 3.\n");
-	os::ftl_to_userspace(elf_header.entry);
+	os::ftl_to_userspace(elf_header.entry, stack + stack_size);
 }
diff --git a/kernel/src/paging.hpp b/kernel/src/paging.hpp
index 677caa2..faa8168 100644
--- a/kernel/src/paging.hpp
+++ b/kernel/src/paging.hpp
@@ -213,7 +213,7 @@ struct page_table_indices {
 	std::uint16_t pde;
 	std::uint16_t pe;
 };
-constexpr page_table_indices calc_page_table_indices(void* ptr) {
+constexpr page_table_indices calc_page_table_indices(const void* ptr) {
 	return {
 		.pml4e = std::uint16_t((std::uint64_t(ptr) >> 39) & 0x1FF),
 		.pdpe  = std::uint16_t((std::uint64_t(ptr) >> 30) & 0x1FF),
diff --git a/kernel/src/ring3.S b/kernel/src/ring3.S
index f648cf1..fb88164 100644
--- a/kernel/src/ring3.S
+++ b/kernel/src/ring3.S
@@ -3,6 +3,7 @@
 .globl ftl_to_userspace
 ftl_to_userspace:
 	mov %rdi, %rcx
+	mov %rsi, %rsp
 	mov $0x202, %r11 # EFLAGS
 	sysretq
 
diff --git a/kernel/src/ring3.hpp b/kernel/src/ring3.hpp
index 6bb02b3..da86967 100644
--- a/kernel/src/ring3.hpp
+++ b/kernel/src/ring3.hpp
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <cstdint>
+#include <cstddef>
 
 namespace os {
 
@@ -24,6 +25,6 @@ struct __attribute__((packed)) tss {
 void set_ring0_stack(tss& tss, std::uint64_t stack);
 extern "C" void load_tss();
 void enable_syscalls();
-extern "C" void ftl_to_userspace(void* program);
+extern "C" void ftl_to_userspace(void* program, std::byte* stack);
 
 }
diff --git a/test_module/module.mk b/test_module/module.mk
index 437b5bf..ddfcc32 100644
--- a/test_module/module.mk
+++ b/test_module/module.mk
@@ -13,7 +13,7 @@ EXEC_NAME := test-module.elf64
 TO_ISO += isodir/boot/$(EXEC_NAME)
 TO_CLEAN += $(OUT_DIR) $(DEP_DIR)
 
-LOCAL_CXXFLAGS := $(CXXFLAGS) -O0
+LOCAL_CXXFLAGS := $(CXXFLAGS)
 LOCAL_LDFLAGS := $(LDFLAGS) -T test_module/linker.ld -z max-page-size=0x1000
 
 CPPOBJS := $(patsubst $(SRC_DIR)%,$(OUT_DIR)%.o,$(shell find $(SRC_DIR) -name '*.cpp'))
diff --git a/test_module/src/test.cpp b/test_module/src/test.cpp
index b2d8219..24da563 100644
--- a/test_module/src/test.cpp
+++ b/test_module/src/test.cpp
@@ -1,13 +1,17 @@
+#include <stddef.h>
+
 extern "C" void print(char c);
 
 void printstr(const char* str) {
-	for (int i = 0; str[i] != '\0'; i++) {
+	for (size_t i = 0; str[i] != '\0'; i++) {
 		print(str[i]);
 	}
 }
 
 extern "C" void _start() {
 	const char* str = "ACAB cependant.\n";
-	printstr(str);
+	for (size_t i = 0; i < 16; i++) {
+		printstr(str);
+	}
 	while (true) {}
 }
-- 
2.46.0