From c96eb734befabb3c42510d402aa801dca8eaa050 Mon Sep 17 00:00:00 2001
From: Amelia Coutard <eliottulio.coutard@gmail.com>
Date: Tue, 5 Dec 2023 15:57:31 +0100
Subject: [PATCH] Made page a template, to support all the different page sizes

---
 kernel/src/kernel.cpp | 22 +++++-----
 kernel/src/paging.cpp | 24 ++++++-----
 kernel/src/paging.hpp | 97 ++++++++++++++++++++++---------------------
 3 files changed, 73 insertions(+), 70 deletions(-)

diff --git a/kernel/src/kernel.cpp b/kernel/src/kernel.cpp
index 82ac43f..beddfac 100644
--- a/kernel/src/kernel.cpp
+++ b/kernel/src/kernel.cpp
@@ -24,7 +24,7 @@ extern "C" os::tss TSS;
 extern "C" char interrupt_stack_top;
 extern "C" os::paging::PML4T PML4T;
 
-os::paging::page bootstrap_pages_for_memory[32]; // 32 pages = 128 KiB
+os::paging::page<0> bootstrap_pages_for_memory[32]; // 32 pages = 128 KiB
 
 extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_start> info) {
 	os::assert(magic == 0x36D76289, "Incorrect magic number: wasn't booted with multiboot2.");
@@ -37,7 +37,7 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
 	}
 
 	os::paging::page_allocator.deallocate({
-		.ptr = os::phys_ptr<os::paging::page>(reinterpret_cast<uintptr_t>(bootstrap_pages_for_memory) - 0xFFFF'FFFF'8000'0000),
+		.ptr = os::phys_ptr<os::paging::page<0>>(reinterpret_cast<uintptr_t>(bootstrap_pages_for_memory) - 0xFFFF'FFFF'8000'0000),
 		.size = sizeof(bootstrap_pages_for_memory) / sizeof(bootstrap_pages_for_memory[0])
 	});
 
@@ -94,8 +94,8 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
 
 	{
 		struct {
-			os::phys_ptr<os::paging::page> start_address = nullptr;
-			os::phys_ptr<os::paging::page> end_address = nullptr;
+			os::phys_ptr<os::paging::page<0>> start_address = nullptr;
+			os::phys_ptr<os::paging::page<0>> end_address = nullptr;
 		} available_ram[50];
 		std::size_t available_ram_length = 0;
 
@@ -112,10 +112,10 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
 				for (std::size_t i = 0; i < multiboot2::memory_map_number_of_entries(it); i++) {
 					if (multiboot2::memory_map_type(it, i) == 1) {
 						// Rounded up, to avoid including non-ram.
-						const os::phys_ptr<os::paging::page>
+						const os::phys_ptr<os::paging::page<0>>
 							s{(multiboot2::memory_map_base_addr(it, i) + 0x1000 - 1) / 0x1000 * 0x1000};
 						// Rounded down, to avoid including non-ram.
-						const os::phys_ptr<os::paging::page>
+						const os::phys_ptr<os::paging::page<0>>
 							e{(multiboot2::memory_map_base_addr(it, i) + multiboot2::memory_map_length(it, i)) / 0x1000 * 0x1000 - 0x1000};
 						if (s <= e) { // In the case where no full page is included in the ram section, don't add it.
 							os::assert(available_ram_length < 50, "Too much available RAM sections to initialise correctly. Will fix eventually, probably.");
@@ -139,13 +139,13 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
 		os::assert(module_specified, "No modules specified in the multiboot. This is unsupported.");
 
 		// kernel_start and kernel_end are aligned to 4K by the linker script.
-		const os::phys_ptr<os::paging::page> kernel_s = ([]() {
-			os::phys_ptr<os::paging::page> ptr = nullptr;
+		const os::phys_ptr<os::paging::page<0>> kernel_s = ([]() {
+			os::phys_ptr<os::paging::page<0>> ptr = nullptr;
 			asm("mov $_kernel_phys_start,%0" : "=ri"(ptr));
 			return ptr;
 		})();
-		const os::phys_ptr<os::paging::page> kernel_e = ([]() {
-			os::phys_ptr<os::paging::page> ptr = nullptr;
+		const os::phys_ptr<os::paging::page<0>> kernel_e = ([]() {
+			os::phys_ptr<os::paging::page<0>> ptr = nullptr;
 			asm("mov $_kernel_phys_end,%0" : "=ri"(ptr));
 			return ptr - 1; // [s, e], not [s, e[
 		})();
@@ -186,7 +186,7 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
 
 	// Unmap low RAM, and free corresponding page.
 	PML4T.contents[0].non_page.P = false;
-	os::paging::page_allocator.deallocate({.ptr = os::phys_ptr<os::paging::page>(get_base_address(PML4T.contents[0]).get_phys_addr()), .size = 1});
+	os::paging::page_allocator.deallocate({.ptr = os::phys_ptr<os::paging::page<0>>(get_base_address(PML4T.contents[0]).get_phys_addr()), .size = 1});
 	os::invlpg((void*)0x0);
 
 	os::print("Loading ring 3 interrupts stack.\n");
diff --git a/kernel/src/paging.cpp b/kernel/src/paging.cpp
index 8fec68d..532b306 100644
--- a/kernel/src/paging.cpp
+++ b/kernel/src/paging.cpp
@@ -48,42 +48,44 @@ std::byte* os::paging::setup_page(os::paging::PML4T& PML4T, const void* vaddr, b
 		{.page = {.P = 1, .R_W = R_W, .U_S = U_S, .PWT = 0, .PCD = 0, .PAT = 0, .G = (indices.pml4e < 128) ? 0ul : 1ul, .base_address = 0, .NX = 0}};
 	const auto page_alloc = os::paging::page_allocator.allocate(1);
 	memset((void*)page_alloc.ptr, 0, 0x1000);
-	set_page_base_address(PT.contents[indices.pe], os::phys_ptr<os::paging::page>(page_alloc.ptr.get_phys_addr()));
+	set_page_base_address(PT.contents[indices.pe], os::phys_ptr<os::paging::page<0>>(page_alloc.ptr.get_phys_addr()));
 	invlpg(vaddr);
 	return (std::byte*)&*page_alloc.ptr;
 }
 
 namespace {
-void on_all_pages(const os::paging::PT& PT, void f(os::paging::page*, os::phys_ptr<os::paging::page>, std::size_t), std::size_t PT_virt_address) {
+void on_all_pages(const os::paging::PT& PT, void f(os::paging::page<0>*, os::phys_ptr<os::paging::page<0>>, std::size_t), std::size_t PT_virt_address) {
 	for (std::size_t i = 0; i < 512; i++) {
 		if (!PT.contents[i].page.P) {
 			continue;
 		}
 		std::uint64_t virt_address = PT_virt_address | (i * 0x1000ull);
-		f(reinterpret_cast<os::paging::page*>(virt_address), os::paging::get_page_base_address(PT.contents[i]), 0x1000ull);
+		f(reinterpret_cast<os::paging::page<0>*>(virt_address), os::paging::get_page_base_address(PT.contents[i]), 0x1000ull);
 	}
 }
-void on_all_pages(const os::paging::PDT& PDT, void f(os::paging::page*, os::phys_ptr<os::paging::page>, std::size_t), std::size_t PDT_virt_address) {
+void on_all_pages(const os::paging::PDT& PDT, void f(os::paging::page<0>*, os::phys_ptr<os::paging::page<0>>, std::size_t), std::size_t PDT_virt_address) {
 	for (std::size_t i = 0; i < 512; i++) {
 		if (!PDT.contents[i].page.P) {
 			continue;
 		}
 		std::uint64_t virt_address = PDT_virt_address | (i * 0x1000ull * 512ull);
 		if (os::paging::is_page(PDT.contents[i])) {
-			f(reinterpret_cast<os::paging::page*>(virt_address), os::paging::get_page_base_address(PDT.contents[i]), 0x1000ull * 512ull);
+			f(reinterpret_cast<os::paging::page<0>*>(virt_address),
+				os::phys_ptr<os::paging::page<0>>(os::paging::get_page_base_address(PDT.contents[i]).get_phys_addr()), 0x1000ull * 512ull);
 		} else {
 			on_all_pages(*os::paging::get_base_address(PDT.contents[i]), f, virt_address);
 		}
 	}
 }
-void on_all_pages(const os::paging::PDPT& PDPT, void f(os::paging::page*, os::phys_ptr<os::paging::page>, std::size_t), std::size_t PDPT_virt_address) {
+void on_all_pages(const os::paging::PDPT& PDPT, void f(os::paging::page<0>*, os::phys_ptr<os::paging::page<0>>, std::size_t), std::size_t PDPT_virt_address) {
 	for (std::size_t i = 0; i < 512; i++) {
 		if (!PDPT.contents[i].page.P) {
 			continue;
 		}
 		std::uint64_t virt_address = PDPT_virt_address | (i * 0x1000ull * 512ull * 512ull);
 		if (os::paging::is_page(PDPT.contents[i])) {
-			f(reinterpret_cast<os::paging::page*>(virt_address), os::paging::get_page_base_address(PDPT.contents[i]), 0x1000ull * 512ull * 512ull);
+			f(reinterpret_cast<os::paging::page<0>*>(virt_address),
+				os::phys_ptr<os::paging::page<0>>(os::paging::get_page_base_address(PDPT.contents[i]).get_phys_addr()), 0x1000ull * 512ull * 512ull);
 		} else {
 			on_all_pages(*os::paging::get_base_address(PDPT.contents[i]), f, virt_address);
 		}
@@ -92,7 +94,7 @@ void on_all_pages(const os::paging::PDPT& PDPT, void f(os::paging::page*, os::ph
 
 } // namespace
 
-void os::paging::on_all_pages(const os::paging::PML4T& PML4T, void f(page*, phys_ptr<page>, std::size_t)) {
+void os::paging::on_all_pages(const os::paging::PML4T& PML4T, void f(page<0>*, phys_ptr<page<0>>, std::size_t)) {
 	for (std::size_t i = 0; i < 512; i++) {
 		if (!PML4T.contents[i].non_page.P) {
 			continue;
@@ -121,7 +123,7 @@ os::paging::page_allocator_t::block os::paging::page_allocator_t::allocate(std::
 		return { .ptr = nullptr, .size = count };
 	}
 	if (begin->size == count) {
-		block result = { .ptr = phys_ptr<paging::page>{begin.get_phys_addr()}, .size = count };
+		block result = { .ptr = phys_ptr<paging::page<0>>{begin.get_phys_addr()}, .size = count };
 		begin = begin->next;
 		return result;
 	}
@@ -129,11 +131,11 @@ os::paging::page_allocator_t::block os::paging::page_allocator_t::allocate(std::
 	for (phys_ptr<page> it = begin; it != nullptr; it = it->next) {
 		if (it->size == count) {
 			prec->next = it->next;
-			return { .ptr = phys_ptr<paging::page>{it.get_phys_addr()}, .size = count };
+			return { .ptr = phys_ptr<paging::page<0>>{it.get_phys_addr()}, .size = count };
 		}
 		if (it->size > count) {
 			it->size -= count;
-			return { .ptr = phys_ptr<paging::page>{it.get_phys_addr()} + it->size, .size = count };
+			return { .ptr = phys_ptr<paging::page<0>>{it.get_phys_addr()} + it->size, .size = count };
 		}
 		prec = it;
 	}
diff --git a/kernel/src/paging.hpp b/kernel/src/paging.hpp
index 41ac1be..d4ed443 100644
--- a/kernel/src/paging.hpp
+++ b/kernel/src/paging.hpp
@@ -24,23 +24,24 @@ template <std::size_t depth> struct paging_entry;
 template <std::size_t depth> struct __attribute__((aligned(0x1000))) paging_table {
 	paging_entry<depth> contents[512];
 };
-template <> struct __attribute__((aligned(0x1000))) paging_table<0> {
-	std::byte contents[4096];
-};
-static_assert(sizeof(paging_table<0>) == 0x1000);
-static_assert(alignof(paging_table<0>) == 0x1000);
 
-using PML4T = paging_table<4>;
-using PML4E = paging_entry<4>;
-using PDPT = paging_table<3>;
-using PDPE = paging_entry<3>;
-using PDT = paging_table<2>;
-using PDE = paging_entry<2>;
-using PT = paging_table<1>;
-using PE = paging_entry<1>;
-using page = paging_table<0>;
+using PML4T = paging_table<3>;
+using PML4E = paging_entry<3>;
+using PDPT = paging_table<2>;
+using PDPE = paging_entry<2>;
+using PDT = paging_table<1>;
+using PDE = paging_entry<1>;
+using PT = paging_table<0>;
+using PE = paging_entry<0>;
+// Alignment should be the same as size, but that's literally too big for the compiler.
+template <std::size_t depth> struct __attribute__((aligned(0x1000))) page {
+	std::byte contents[0x1000ull << (9 * depth)];
+};
+static_assert(sizeof(page<0>) == 0x1000);
+static_assert(sizeof(page<1>) == 0x1000 * 512);
+static_assert(sizeof(page<2>) == 0x1000ull * 512 * 512);
 
-template<> struct paging_entry<4> {
+template<> struct paging_entry<3> {
 	struct {
 		std::uint64_t P : 1 = 0;
 		std::uint64_t R_W : 1;
@@ -56,12 +57,12 @@ template<> struct paging_entry<4> {
 		std::uint64_t NX : 1;
 	} non_page;
 };
-static_assert(sizeof(paging_table<4>) == 0x1000);
-static_assert(alignof(paging_table<4>) == 0x1000);
-static_assert(alignof(paging_entry<4>) == 8);
-static_assert(sizeof(paging_entry<4>) == 8);
+static_assert(sizeof(paging_table<3>) == 0x1000);
+static_assert(alignof(paging_table<3>) == 0x1000);
+static_assert(alignof(paging_entry<3>) == 8);
+static_assert(sizeof(paging_entry<3>) == 8);
 
-template<> struct paging_entry<3> {
+template<> struct paging_entry<2> {
 union {
 	struct {
 		std::uint64_t P : 1 = 0;
@@ -98,12 +99,12 @@ union {
 	} page;
 };
 };
-static_assert(sizeof(paging_table<3>) == 0x1000);
-static_assert(alignof(paging_table<3>) == 0x1000);
-static_assert(alignof(paging_entry<3>) == 8);
-static_assert(sizeof(paging_entry<3>) == 8);
+static_assert(sizeof(paging_table<2>) == 0x1000);
+static_assert(alignof(paging_table<2>) == 0x1000);
+static_assert(alignof(paging_entry<2>) == 8);
+static_assert(sizeof(paging_entry<2>) == 8);
 
-template <> struct paging_entry<2> {
+template <> struct paging_entry<1> {
 union {
 	struct {
 		std::uint64_t P : 1 = 0;
@@ -140,12 +141,12 @@ union {
 	} page;
 };
 };
-static_assert(sizeof(paging_table<2>) == 0x1000);
-static_assert(alignof(paging_table<2>) == 0x1000);
-static_assert(alignof(paging_entry<2>) == 8);
-static_assert(sizeof(paging_entry<2>) == 8);
+static_assert(sizeof(paging_table<1>) == 0x1000);
+static_assert(alignof(paging_table<1>) == 0x1000);
+static_assert(alignof(paging_entry<1>) == 8);
+static_assert(sizeof(paging_entry<1>) == 8);
 
-template <> struct paging_entry<1> {
+template <> struct paging_entry<0> {
 	struct {
 		std::uint64_t P : 1 = 0;
 		std::uint64_t R_W : 1;
@@ -163,10 +164,10 @@ template <> struct paging_entry<1> {
 		std::uint64_t NX : 1;
 	} page;
 };
-static_assert(sizeof(paging_table<1>) == 0x1000);
-static_assert(alignof(paging_table<1>) == 0x1000);
-static_assert(alignof(paging_entry<1>) == 8);
-static_assert(sizeof(paging_entry<1>) == 8);
+static_assert(sizeof(paging_table<0>) == 0x1000);
+static_assert(alignof(paging_table<0>) == 0x1000);
+static_assert(alignof(paging_entry<0>) == 8);
+static_assert(sizeof(paging_entry<0>) == 8);
 
 
 inline bool is_page(const PML4E __attribute__((unused))& PML4E) { return false; }
@@ -186,17 +187,17 @@ inline phys_ptr<PT> get_base_address(const PDE& PDE) {
 	os::assert(!is_page(PDE), "Tried to get non-page out of a page paging.");
 	return phys_ptr<PT>{PDE.non_page.base_address * 0x1000ull};
 }
-inline phys_ptr<page> get_page_base_address(const PDPE& PDPE) {
+inline phys_ptr<page<2>> get_page_base_address(const PDPE& PDPE) {
 	os::assert(is_page(PDPE), "Tried to get page out of a non-page paging.");
-	return phys_ptr<page>{PDPE.page.base_address * 0x1000ull * 512ull * 512ull};
+	return phys_ptr<page<2>>{PDPE.page.base_address * sizeof(page<2>)};
 }
-inline phys_ptr<page> get_page_base_address(const PDE& PDE) {
+inline phys_ptr<page<1>> get_page_base_address(const PDE& PDE) {
 	os::assert(is_page(PDE), "Tried to get page out of a non-page paging.");
-	return phys_ptr<page>{PDE.page.base_address * 0x1000ull * 512ull};
+	return phys_ptr<page<1>>{PDE.page.base_address * sizeof(page<1>)};
 }
-inline phys_ptr<page> get_page_base_address(const PE& PE) {
+inline phys_ptr<page<0>> get_page_base_address(const PE& PE) {
 	os::assert(is_page(PE), "Tried to get page out of a non-page paging.");
-	return phys_ptr<page>{PE.page.base_address * 0x1000ull};
+	return phys_ptr<page<0>>{PE.page.base_address * sizeof(page<0>)};
 }
 
 inline void set_base_address(PML4E& PML4E, phys_ptr<PDPT> ptr) {
@@ -211,17 +212,17 @@ inline void set_base_address(PDE& PDE, phys_ptr<PT> ptr) {
 	os::assert(!is_page(PDE), "Tried to get non-page out of a page paging.");
 	PDE.non_page.base_address = ptr.get_phys_addr() / 0x1000ull;
 }
-inline void set_page_base_address(PDPE& PDPE, phys_ptr<page> ptr) {
+inline void set_page_base_address(PDPE& PDPE, phys_ptr<page<2>> ptr) {
 	os::assert(is_page(PDPE), "Tried to get page out of a non-page paging.");
-	PDPE.page.base_address = ptr.get_phys_addr() / 0x1000ull / 512ull / 512ull;
+	PDPE.page.base_address = ptr.get_phys_addr() / sizeof(*ptr);
 }
-inline void set_page_base_address(PDE& PDE, phys_ptr<page> ptr) {
+inline void set_page_base_address(PDE& PDE, phys_ptr<page<1>> ptr) {
 	os::assert(is_page(PDE), "Tried to get page out of a non-page paging.");
-	PDE.page.base_address = ptr.get_phys_addr() / 0x1000ull / 512ull;
+	PDE.page.base_address = ptr.get_phys_addr() / sizeof(*ptr);
 }
-inline void set_page_base_address(PE& PE, phys_ptr<page> ptr) {
+inline void set_page_base_address(PE& PE, phys_ptr<page<0>> ptr) {
 	os::assert(is_page(PE), "Tried to get page out of a non-page paging.");
-	PE.page.base_address = ptr.get_phys_addr() / 0x1000ull;
+	PE.page.base_address = ptr.get_phys_addr() / sizeof(*ptr);
 }
 
 struct page_table_indices {
@@ -242,7 +243,7 @@ constexpr page_table_indices calc_page_table_indices(const void* ptr) {
 std::byte* setup_page(PML4T& PML4T, const void* vaddr, bool R_W, bool U_S);
 
 // For all present page mappings, calls f(virtual address, physical address, page size in bytes (4KiB, 2MiB or 1GiB)).
-void on_all_pages(const PML4T& PML4T, void f(page*, phys_ptr<page>, std::size_t));
+void on_all_pages(const PML4T& PML4T, void f(page<0>*, phys_ptr<page<0>>, std::size_t));
 
 void load_pml4t(phys_ptr<PML4T> PML4T);
 
@@ -252,7 +253,7 @@ extern page_allocator_t page_allocator;
 class page_allocator_t {
 public:
 	struct block {
-		phys_ptr<paging::page> ptr = nullptr;
+		phys_ptr<paging::page<0>> ptr = nullptr;
 		std::uint64_t size;
 	};
 
-- 
2.46.0