]> git.ameliathe1st.gay Git - voyage-au-centre-des-fichiers.git/commitdiff
Moved the code to initialise an empty page in the PML4T to paging.cpp
authorAmelia Coutard <eliottulio.coutard@gmail.com>
Sun, 5 Mar 2023 03:00:02 +0000 (04:00 +0100)
committerAmelia Coutard <eliottulio.coutard@gmail.com>
Sun, 5 Mar 2023 03:00:02 +0000 (04:00 +0100)
kernel/src/interrupts.hpp
kernel/src/kernel.cpp
kernel/src/paging.cpp
kernel/src/paging.hpp
kernel/src/serial.cpp
kernel/src/serial.hpp

index 847fe3a44e21371817371710f3a17ca78c85c7c9..2ab8701f9ea2c5feb632b647fc5092ca2fa74e90 100644 (file)
@@ -59,6 +59,9 @@ struct isr_info {
        bool present : 1 = true;
 };
 
+template<typename... Ts>
+void assert(bool cond, const char* format, const Ts&... vs);
+
 template<size_t interrupt_nb>
 void enable_interrupts(const isr_info (&ISRs)[interrupt_nb], os::idt<interrupt_nb>& idt) {
        os::assert(is_APIC_builtin(), "No builtin APIC.");
index 8c1569cbe9baaf9127619dcd8711f021a2dc693c..7db816dade8ac56d06f39e163e471fd23f392d02 100644 (file)
@@ -207,31 +207,8 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
                // Allocate memory for segment:
                std::size_t nb_pages = (std::uint64_t(program_header.p_vaddr) % 0x1000 + program_header.p_memsz + 0x1000 - 1) / 0x1000;
                for (std::size_t i = 0; i < nb_pages; i++) {
-                       const auto indices = os::paging::calc_page_table_indices(program_header.p_vaddr + i * 0x1000);
-                       if (PML4T.contents[indices.pml4e].P == 0) {
-                               PML4T.contents[indices.pml4e] = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0};
-                               const auto PDPT_alloc = os::paging::page_allocator.allocate(1);
-                               set_base_address(PML4T.contents[indices.pml4e], os::phys_ptr<os::paging::PDPT>(PDPT_alloc.ptr.get_phys_addr()));
-                       }
-                       os::paging::PDPT& PDPT = *get_base_address(PML4T.contents[indices.pml4e]);
-                       if (PDPT.contents[indices.pdpe].non_page.P == 0) {
-                               PDPT.contents[indices.pdpe] = {.non_page = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
-                               const auto PDT_alloc = os::paging::page_allocator.allocate(1);
-                               set_base_address(PDPT.contents[indices.pdpe], os::phys_ptr<os::paging::PDT>(PDT_alloc.ptr.get_phys_addr()));
-                       }
-                       os::paging::PDT& PDT = *get_base_address(PDPT.contents[indices.pdpe]);
-                       if (PDT.contents[indices.pde].non_page.P == 0) {
-                               PDT.contents[indices.pde] = {.non_page = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
-                               const auto PT_alloc = os::paging::page_allocator.allocate(1);
-                               set_base_address(PDT.contents[indices.pde], os::phys_ptr<os::paging::PT>(PT_alloc.ptr.get_phys_addr()));
-                       }
-                       os::paging::PT& PT = *get_base_address(PDT.contents[indices.pde]);
-                       os::assert(PT.contents[indices.pe].P == 0, "Process segments' memory overlaps the same pages.");
-                       PT.contents[indices.pe] = {.P = 1, .R_W = (program_header.flags & 2) >> 1, .U_S = 1, .PWT = 0, .PCD = 0, .PAT = 0, .G = 0, .base_address = 0, .NX = 0};
-                       const auto page_alloc = os::paging::page_allocator.allocate(1);
-                       set_page_base_address(PT.contents[indices.pe], os::phys_ptr<os::paging::page>(page_alloc.ptr.get_phys_addr()));
+                       os::paging::setup_page(PML4T, program_header.p_vaddr + i * 0x1000, (program_header.flags & 2) >> 1, 1, true);
                }
-
                // Initialise memory for segment:
                for (std::size_t i = 0; i < program_header.p_filesz; i++) {
                        program_header.p_vaddr[i] = os::phys_ptr<std::byte>(test_module.start_address.get_phys_addr())[program_header.p_offset + i];
@@ -245,35 +222,7 @@ extern "C" void kmain(unsigned long magic, os::phys_ptr<const multiboot2::info_s
        constexpr std::size_t stack_size = 16 * 0x1000 /* 64KiB */;
        std::byte * const stack = (std::byte*)0x0000'8000'0000'0000 - stack_size;
        for (std::size_t i = 0; i < stack_size / 0x1000 /* 64KiB */; i++) {
-               const auto indices = os::paging::calc_page_table_indices(stack + i * 0x1000);
-               os::assert(indices.pml4e < 256, "Userspace program must be in the lower-half of virtual memory.");
-               if (PML4T.contents[indices.pml4e].P == 0) {
-                       PML4T.contents[indices.pml4e] = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0};
-                       const auto PDPT_alloc = os::paging::page_allocator.allocate(1);
-                       set_base_address(PML4T.contents[indices.pml4e], os::phys_ptr<os::paging::PDPT>(PDPT_alloc.ptr.get_phys_addr()));
-               }
-               os::paging::PDPT& PDPT = *get_base_address(PML4T.contents[indices.pml4e]);
-               if (PDPT.contents[indices.pdpe].non_page.P == 0) {
-                       PDPT.contents[indices.pdpe] = {.non_page = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
-                       const auto PDT_alloc = os::paging::page_allocator.allocate(1);
-                       set_base_address(PDPT.contents[indices.pdpe], os::phys_ptr<os::paging::PDT>(PDT_alloc.ptr.get_phys_addr()));
-               }
-               os::paging::PDT& PDT = *get_base_address(PDPT.contents[indices.pdpe]);
-               if (PDT.contents[indices.pde].non_page.P == 0) {
-                       PDT.contents[indices.pde] = {.non_page = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
-                       const auto PT_alloc = os::paging::page_allocator.allocate(1);
-                       set_base_address(PDT.contents[indices.pde], os::phys_ptr<os::paging::PT>(PT_alloc.ptr.get_phys_addr()));
-               }
-               os::paging::PT& PT = *get_base_address(PDT.contents[indices.pde]);
-               os::assert(PT.contents[indices.pe].P == 0, "Process segments' memory overlaps the same pages.");
-               PT.contents[indices.pe] = {.P = 1, .R_W = 1, .U_S = 1, .PWT = 0, .PCD = 0, .PAT = 0, .G = 0, .base_address = 0, .NX = 0};
-               const auto page_alloc = os::paging::page_allocator.allocate(1);
-               set_page_base_address(PT.contents[indices.pe], os::phys_ptr<os::paging::page>(page_alloc.ptr.get_phys_addr()));
-       }
-
-       // Initialise stack to 0.
-       for (std::size_t i = 0; i < stack_size; i++) {
-               stack[i] = std::byte(0);
+               os::paging::setup_page(PML4T, stack + i * 0x1000, 1, 1, true);
        }
 
        asm volatile("invlpg (%0)" ::"r" (0x1000) : "memory");
index 0d02f0d346f3fbbc91e783045c20ee7f1252c8c3..d3dc1cbd40ec75d93e6c83748cbd1f7a23546fe3 100644 (file)
 #include "paging.hpp"
 #include "serial.hpp"
 
+void os::paging::setup_page(os::paging::PML4T& PML4T, const void* vaddr, bool R_W, bool U_S, bool must_be_unmapped) {
+       const auto indices = os::paging::calc_page_table_indices(vaddr);
+       if (PML4T.contents[indices.pml4e].P == 0) {
+               PML4T.contents[indices.pml4e] = {.P = 1, .R_W = 1, .U_S = U_S, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0};
+               const auto PDPT_alloc = os::paging::page_allocator.allocate(1);
+               set_base_address(PML4T.contents[indices.pml4e], os::phys_ptr<os::paging::PDPT>(PDPT_alloc.ptr.get_phys_addr()));
+       }
+       os::paging::PDPT& PDPT = *get_base_address(PML4T.contents[indices.pml4e]);
+       if (PDPT.contents[indices.pdpe].non_page.P == 0) {
+               PDPT.contents[indices.pdpe] = {.non_page = {.P = 1, .R_W = 1, .U_S = U_S, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
+               const auto PDT_alloc = os::paging::page_allocator.allocate(1);
+               set_base_address(PDPT.contents[indices.pdpe], os::phys_ptr<os::paging::PDT>(PDT_alloc.ptr.get_phys_addr()));
+       }
+       os::paging::PDT& PDT = *get_base_address(PDPT.contents[indices.pdpe]);
+       if (PDT.contents[indices.pde].non_page.P == 0) {
+               PDT.contents[indices.pde] = {.non_page = {.P = 1, .R_W = 1, .U_S = U_S, .PWT = 0, .PCD = 0, .base_address = 0, .NX = 0}};
+               const auto PT_alloc = os::paging::page_allocator.allocate(1);
+               set_base_address(PDT.contents[indices.pde], os::phys_ptr<os::paging::PT>(PT_alloc.ptr.get_phys_addr()));
+       }
+       os::paging::PT& PT = *get_base_address(PDT.contents[indices.pde]);
+       if (PT.contents[indices.pe].P == 0) {
+               PT.contents[indices.pe] = {.P = 1, .R_W = R_W, .U_S = U_S, .PWT = 0, .PCD = 0, .PAT = 0, .G = 0, .base_address = 0, .NX = 0};
+               const auto page_alloc = os::paging::page_allocator.allocate(1);
+               set_page_base_address(PT.contents[indices.pe], os::phys_ptr<os::paging::page>(page_alloc.ptr.get_phys_addr()));
+       } else {
+               os::assert(!must_be_unmapped, "Memory at address 0x{} has already been mapped.", std::uintptr_t(vaddr));
+       }
+}
+
 namespace {
 void on_all_pages(const os::paging::PT& PT, void f(os::paging::page*, os::phys_ptr<os::paging::page>, std::size_t), std::size_t PT_virt_address) {
        for (std::size_t i = 0; i < 512; i++) {
index 0d700ab6fdabc3f5a83af3de0d827b2a1cdd9206..a9a194692b1c9c4e35f2353cbb72778350042346 100644 (file)
@@ -235,6 +235,8 @@ constexpr page_table_indices calc_page_table_indices(const void* ptr) {
        };
 }
 
+void setup_page(PML4T& PML4T, const void* vaddr, bool R_W, bool U_S, bool must_be_unmapped);
+
 // For all present page mappings, calls f(virtual address, physical address, page size in bytes (4KiB, 2MiB or 1GiB)).
 void on_all_pages(const PML4T& PML4T, void f(page*, phys_ptr<page>, std::size_t));
 
index b482d9cefa13531a67084378ce28f0ef412eedb8..d02bc12c34eb9d471a33c4f7b747f382a2a5b5bc 100644 (file)
@@ -12,9 +12,9 @@
 // not, see <https://www.gnu.org/licenses/>.
 
 #include <cstddef>
-#include "serial.hpp"
 #include "utils.hpp"
 #include "interrupts.hpp"
+#include "serial.hpp"
 
 bool os::init_serial_port() {
        outb(serial_port + 1, 0x00); // Disable interrupts.
@@ -75,10 +75,3 @@ void os::print_formatted(const char* format, std::int64_t val) {
                os::print_formatted(format, std::uint64_t(val));
        }
 }
-void os::assert(bool cond, const char* diagnostic) {
-       if (!cond) {
-               os::print("Error: {}\n", diagnostic);
-               os::cli();
-               while (true) { os::hlt(); }
-       }
-}
index dfd5f48d0928ce3e39535bce1a1ff456db1612ac..c64de07e533de6c4d5e4ac6e9d4c75a2b8adb08d 100644 (file)
@@ -17,6 +17,9 @@
 
 namespace os {
 
+template<typename... Ts>
+void assert(bool cond, const char* format, const Ts&... vs);
+
 constexpr std::uint16_t serial_port{0x3F8};
 
 bool init_serial_port();
@@ -26,8 +29,6 @@ std::uint8_t read_serial();
 bool serial_transmit_empty();
 void write_serial(std::uint8_t v);
 
-void assert(bool cond, const char* diagnostic);
-
 void printc(char c);
 void print_formatted(const char* format, const char* val);
 void print_formatted(const char* format, std::uint64_t val);
@@ -67,7 +68,7 @@ void print(const char* format, const Ts&... vs) {
                if (format[i] == '{') {
                        i++;
                        if (format[i] == '\0') {
-                               os::assert(false, "Error in format string: unterminated '{}'.");
+                               os::assert(false, "Error in format string: unterminated '{{}'.");
                        } else if (format[i] == '{') {
                                printc('{');
                                continue;
@@ -75,7 +76,7 @@ void print(const char* format, const Ts&... vs) {
                                std::size_t format_spec_end = i;
                                while (format[format_spec_end] != '}') {
                                        if (format[format_spec_end++] == '\0') {
-                                               os::assert(false, "Error in format string: unterminated '{}'.");
+                                               os::assert(false, "Error in format string: unterminated '{{}'.");
                                        }
                                }
                                std::size_t n = arg_n;
@@ -106,7 +107,7 @@ void print(const char* format, const Ts&... vs) {
                                arg_n++;
                        }
                } else if (format[i] == '}') {
-                       os::assert(format[i + 1] == '}', "Error in format strin: unexpected '}'.");
+                       os::assert(format[i + 1] == '}', "Error in format string: unexpected '}'.");
                        i++;
                        printc('}');
                } else {
@@ -115,4 +116,18 @@ void print(const char* format, const Ts&... vs) {
        }
 }
 
+void cli();
+void hlt();
+
+template<typename... Ts>
+void assert(bool cond, const char* format, const Ts&... vs) {
+       if (!cond) {
+               print("Error: ");
+               print(format, vs...);
+               printc('\n');
+               cli();
+               while (true) { hlt(); }
+       }
+}
+
 } // namespace os