From: Amelia Coutard <eliottulio.coutard@gmail.com>
Date: Sun, 5 Mar 2023 03:00:02 +0000 (+0100)
Subject: Moved the code to initialise an empty page in the PML4T to paging.cpp
X-Git-Url: https://git.ameliathe1st.gay/?a=commitdiff_plain;h=271da5bf705606c24239a27c1c51afaac2bf0450;p=voyage-au-centre-des-fichiers.git

Moved the code to initialise an empty page in the PML4T to paging.cpp
---

diff --git a/kernel/src/interrupts.hpp b/kernel/src/interrupts.hpp
index 847fe3a..2ab8701 100644
--- a/kernel/src/interrupts.hpp
+++ b/kernel/src/interrupts.hpp
@@ -59,6 +59,9 @@ struct isr_info {
 	bool present : 1 = true;
 };
 
+template<typename... Ts>
+void assert(bool cond, const char* format, const Ts&... vs);
+
 template<size_t interrupt_nb>
 void enable_interrupts(const isr_info (&ISRs)[interrupt_nb], os::idt<interrupt_nb>& idt) {
 	os::assert(is_APIC_builtin(), "No builtin APIC.");
diff --git a/kernel/src/kernel.cpp b/kernel/src/kernel.cpp
index 8c1569c..7db816d 100644
--- a/kernel/src/kernel.cpp
+++ b/kernel/src/kernel.cpp
@@ -207,31 +207,8 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
 		// Allocate memory for segment:
 		std::size_t nb_pages = (std::uint64_t(program_header.p_vaddr) % 0x1000 + program_header.p_memsz + 0x1000 - 1) / 0x1000;
 		for (std::size_t i = 0; i < nb_pages; i++) {
-			const auto indices = os::paging::calc_page_table_indices(program_header.p_vaddr + i * 0x1000);
-			if (PML4T.contents[indices.pml4e].P == 0) {
-				PML4T.contents[indices.pml4e] = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0};
-				const auto PDPT_alloc = os::paging::page_allocator.allocate(1);
-				set_base_address(PML4T.contents[indices.pml4e], os::phys_ptr<os::paging::PDPT>(PDPT_alloc.ptr.get_phys_addr()));
-			}
-			os::paging::PDPT& PDPT = *get_base_address(PML4T.contents[indices.pml4e]);
-			if (PDPT.contents[indices.pdpe].non_page.P == 0) {
-				PDPT.contents[indices.pdpe] = {.non_page = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
-				const auto PDT_alloc = os::paging::page_allocator.allocate(1);
-				set_base_address(PDPT.contents[indices.pdpe], os::phys_ptr<os::paging::PDT>(PDT_alloc.ptr.get_phys_addr()));
-			}
-			os::paging::PDT& PDT = *get_base_address(PDPT.contents[indices.pdpe]);
-			if (PDT.contents[indices.pde].non_page.P == 0) {
-				PDT.contents[indices.pde] = {.non_page = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
-				const auto PT_alloc = os::paging::page_allocator.allocate(1);
-				set_base_address(PDT.contents[indices.pde], os::phys_ptr<os::paging::PT>(PT_alloc.ptr.get_phys_addr()));
-			}
-			os::paging::PT& PT = *get_base_address(PDT.contents[indices.pde]);
-			os::assert(PT.contents[indices.pe].P == 0, "Process segments' memory overlaps the same pages.");
-			PT.contents[indices.pe] = {.P = 1, .R_W = (program_header.flags & 2) >> 1, .U_S = 1, .PWT = 0, .PCD = 0, .PAT = 0, .G = 0, .base_address = 0, .NX = 0};
-			const auto page_alloc = os::paging::page_allocator.allocate(1);
-			set_page_base_address(PT.contents[indices.pe], os::phys_ptr<os::paging::page>(page_alloc.ptr.get_phys_addr()));
+			os::paging::setup_page(PML4T, program_header.p_vaddr + i * 0x1000, (program_header.flags & 2) >> 1, 1, true);
 		}
-
 		// Initialise memory for segment:
 		for (std::size_t i = 0; i < program_header.p_filesz; i++) {
 			program_header.p_vaddr[i] = os::phys_ptr<std::byte>(test_module.start_address.get_phys_addr())[program_header.p_offset + i];
@@ -245,35 +222,7 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
 	constexpr std::size_t stack_size = 16 * 0x1000 /* 64KiB */;
 	std::byte * const stack = (std::byte*)0x0000'8000'0000'0000 - stack_size;
 	for (std::size_t i = 0; i < stack_size / 0x1000 /* 64KiB */; i++) {
-		const auto indices = os::paging::calc_page_table_indices(stack + i * 0x1000);
-		os::assert(indices.pml4e < 256, "Userspace program must be in the lower-half of virtual memory.");
-		if (PML4T.contents[indices.pml4e].P == 0) {
-			PML4T.contents[indices.pml4e] = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0};
-			const auto PDPT_alloc = os::paging::page_allocator.allocate(1);
-			set_base_address(PML4T.contents[indices.pml4e], os::phys_ptr<os::paging::PDPT>(PDPT_alloc.ptr.get_phys_addr()));
-		}
-		os::paging::PDPT& PDPT = *get_base_address(PML4T.contents[indices.pml4e]);
-		if (PDPT.contents[indices.pdpe].non_page.P == 0) {
-			PDPT.contents[indices.pdpe] = {.non_page = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
-			const auto PDT_alloc = os::paging::page_allocator.allocate(1);
-			set_base_address(PDPT.contents[indices.pdpe], os::phys_ptr<os::paging::PDT>(PDT_alloc.ptr.get_phys_addr()));
-		}
-		os::paging::PDT& PDT = *get_base_address(PDPT.contents[indices.pdpe]);
-		if (PDT.contents[indices.pde].non_page.P == 0) {
-			PDT.contents[indices.pde] = {.non_page = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
-			const auto PT_alloc = os::paging::page_allocator.allocate(1);
-			set_base_address(PDT.contents[indices.pde], os::phys_ptr<os::paging::PT>(PT_alloc.ptr.get_phys_addr()));
-		}
-		os::paging::PT& PT = *get_base_address(PDT.contents[indices.pde]);
-		os::assert(PT.contents[indices.pe].P == 0, "Process segments' memory overlaps the same pages.");
-		PT.contents[indices.pe] = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .PAT = 0, .G = 0, .base_address = 0, .NX = 0};
-		const auto page_alloc = os::paging::page_allocator.allocate(1);
-		set_page_base_address(PT.contents[indices.pe], os::phys_ptr<os::paging::page>(page_alloc.ptr.get_phys_addr()));
-	}
-
-	// Initialise stack to 0.
-	for (std::size_t i = 0; i < stack_size; i++) {
-		stack[i] = std::byte(0);
+		os::paging::setup_page(PML4T, stack + i * 0x1000, 1, 1, true);
 	}
 
 	asm volatile("invlpg (%0)" ::"r" (0x1000) : "memory");
diff --git a/kernel/src/paging.cpp b/kernel/src/paging.cpp
index 0d02f0d..d3dc1cb 100644
--- a/kernel/src/paging.cpp
+++ b/kernel/src/paging.cpp
@@ -14,6 +14,35 @@
 #include "paging.hpp"
 #include "serial.hpp"
 
+void os::paging::setup_page(os::paging::PML4T& PML4T, const void* vaddr, bool R_W, bool U_S, bool must_be_unmapped) {
+	const auto indices = os::paging::calc_page_table_indices(vaddr);
+	if (PML4T.contents[indices.pml4e].P == 0) {
+		PML4T.contents[indices.pml4e] = {.P = 1, .R_W = 1, .U_S = U_S, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0};
+		const auto PDPT_alloc = os::paging::page_allocator.allocate(1);
+		set_base_address(PML4T.contents[indices.pml4e], os::phys_ptr<os::paging::PDPT>(PDPT_alloc.ptr.get_phys_addr()));
+	}
+	os::paging::PDPT& PDPT = *get_base_address(PML4T.contents[indices.pml4e]);
+	if (PDPT.contents[indices.pdpe].non_page.P == 0) {
+		PDPT.contents[indices.pdpe] = {.non_page = {.P = 1, .R_W = 1, .U_S = U_S, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
+		const auto PDT_alloc = os::paging::page_allocator.allocate(1);
+		set_base_address(PDPT.contents[indices.pdpe], os::phys_ptr<os::paging::PDT>(PDT_alloc.ptr.get_phys_addr()));
+	}
+	os::paging::PDT& PDT = *get_base_address(PDPT.contents[indices.pdpe]);
+	if (PDT.contents[indices.pde].non_page.P == 0) {
+		PDT.contents[indices.pde] = {.non_page = {.P = 1, .R_W = 1, .U_S = U_S, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
+		const auto PT_alloc = os::paging::page_allocator.allocate(1);
+		set_base_address(PDT.contents[indices.pde], os::phys_ptr<os::paging::PT>(PT_alloc.ptr.get_phys_addr()));
+	}
+	os::paging::PT& PT = *get_base_address(PDT.contents[indices.pde]);
+	if (PT.contents[indices.pe].P == 0) {
+		PT.contents[indices.pe] = {.P = 1, .R_W = R_W, .U_S = U_S, .PWT = 0, .PCD = 0, .PAT = 0, .G = 0, .base_address = 0, .NX = 0};
+		const auto page_alloc = os::paging::page_allocator.allocate(1);
+		set_page_base_address(PT.contents[indices.pe], os::phys_ptr<os::paging::page>(page_alloc.ptr.get_phys_addr()));
+	} else {
+		os::assert(!must_be_unmapped, "Memory at address 0x{} has already been mapped.", std::uintptr_t(vaddr));
+	}
+}
+
 namespace {
 void on_all_pages(const os::paging::PT& PT, void f(os::paging::page*, os::phys_ptr<os::paging::page>, std::size_t), std::size_t PT_virt_address) {
 	for (std::size_t i = 0; i < 512; i++) {
diff --git a/kernel/src/paging.hpp b/kernel/src/paging.hpp
index 0d700ab..a9a1946 100644
--- a/kernel/src/paging.hpp
+++ b/kernel/src/paging.hpp
@@ -235,6 +235,8 @@ constexpr page_table_indices calc_page_table_indices(const void* ptr) {
 	};
 }
 
+void setup_page(PML4T& PML4T, const void* vaddr, bool R_W, bool U_S, bool must_be_unmapped);
+
 // For all present page mappings, calls f(virtual address, physical address, page size in bytes (4KiB, 2MiB or 1GiB)).
 void on_all_pages(const PML4T& PML4T, void f(page*, phys_ptr<page>, std::size_t));
 
diff --git a/kernel/src/serial.cpp b/kernel/src/serial.cpp
index b482d9c..d02bc12 100644
--- a/kernel/src/serial.cpp
+++ b/kernel/src/serial.cpp
@@ -12,9 +12,9 @@
 // not, see <https://www.gnu.org/licenses/>.
 
 #include <cstddef>
-#include "serial.hpp"
 #include "utils.hpp"
 #include "interrupts.hpp"
+#include "serial.hpp"
 
 bool os::init_serial_port() {
 	outb(serial_port + 1, 0x00); // Disable interrupts.
@@ -75,10 +75,3 @@ void os::print_formatted(const char* format, std::int64_t val) {
 		os::print_formatted(format, std::uint64_t(val));
 	}
 }
-void os::assert(bool cond, const char* diagnostic) {
-	if (!cond) {
-		os::print("Error: {}\n", diagnostic);
-		os::cli();
-		while (true) { os::hlt(); }
-	}
-}
diff --git a/kernel/src/serial.hpp b/kernel/src/serial.hpp
index dfd5f48..c64de07 100644
--- a/kernel/src/serial.hpp
+++ b/kernel/src/serial.hpp
@@ -17,6 +17,9 @@
 
 namespace os {
 
+template<typename... Ts>
+void assert(bool cond, const char* format, const Ts&... vs);
+
 constexpr std::uint16_t serial_port{0x3F8};
 
 bool init_serial_port();
@@ -26,8 +29,6 @@ std::uint8_t read_serial();
 bool serial_transmit_empty();
 void write_serial(std::uint8_t v);
 
-void assert(bool cond, const char* diagnostic);
-
 void printc(char c);
 void print_formatted(const char* format, const char* val);
 void print_formatted(const char* format, std::uint64_t val);
@@ -67,7 +68,7 @@ void print(const char* format, const Ts&... vs) {
 		if (format[i] == '{') {
 			i++;
 			if (format[i] == '\0') {
-				os::assert(false, "Error in format string: unterminated '{}'.");
+				os::assert(false, "Error in format string: unterminated '{{}'.");
 			} else if (format[i] == '{') {
 				printc('{');
 				continue;
@@ -75,7 +76,7 @@ void print(const char* format, const Ts&... vs) {
 				std::size_t format_spec_end = i;
 				while (format[format_spec_end] != '}') {
 					if (format[format_spec_end++] == '\0') {
-						os::assert(false, "Error in format string: unterminated '{}'.");
+						os::assert(false, "Error in format string: unterminated '{{}'.");
 					}
 				}
 				std::size_t n = arg_n;
@@ -106,7 +107,7 @@ void print(const char* format, const Ts&... vs) {
 				arg_n++;
 			}
 		} else if (format[i] == '}') {
-			os::assert(format[i + 1] == '}', "Error in format strin: unexpected '}'.");
+			os::assert(format[i + 1] == '}', "Error in format string: unexpected '}'.");
 			i++;
 			printc('}');
 		} else {
@@ -115,4 +116,18 @@ void print(const char* format, const Ts&... vs) {
 	}
 }
 
+void cli();
+void hlt();
+
+template<typename... Ts>
+void assert(bool cond, const char* format, const Ts&... vs) {
+	if (!cond) {
+		print("Error: ");
+		print(format, vs...);
+		printc('\n');
+		cli();
+		while (true) { hlt(); }
+	}
+}
+
 } // namespace os