"""Record types for a file allocation table, directory table and entry-name table.

NOTE(review): the layouts look like the Nintendo DS ROM FAT/FNT formats —
confirm against the format specification.
"""
from typing import overload

from .Utils import uint_le, u8_le, u16_le, u32_le

# Text encoding used for entry names in the name table.
_NAME_ENCODING = 'shift-jis'
# Bit set in a name-table length byte to mark a sub-directory entry.
_DIRECTORY_FLAG = 0x80
# Mask that extracts the name byte length from a directory entry's length byte.
_NAME_LENGTH_MASK = 0x7F


class FileAllocationEntry:
    """Allocation-table record holding a ``start`` and an ``end`` offset.

    Serialized as 8 bytes: ``start`` then ``end``, each written with ``u32_le``.
    """

    @overload
    def __init__(self, offset: int = 0, length: int = 0, /): ...
    @overload
    def __init__(self, data: bytes | None = None, /): ...
    def __init__(self, *args):
        """Build from ``(offset[, length])``, parse raw ``data``, or zero-fill.

        Args:
            offset: Start offset; ``end`` becomes ``offset + length``.
            length: Span length, 0 when omitted.
            data: Encoded entry; ``start`` is read from bytes 0-3 and ``end``
                from bytes 4-7. ``None`` or no argument yields a zeroed entry.
        """
        if not args or args[0] is None:
            self.start = 0
            self.end = 0
        elif isinstance(args[0], int):
            # ``length`` is optional per the overload, so a lone offset must
            # not be routed to the bytes parser.
            offset = args[0]
            length = args[1] if len(args) > 1 else 0
            self.start = offset
            self.end = offset + length
        else:
            data = args[0]
            self.start = uint_le(data[:4])
            self.end = uint_le(data[4:8])

    def encode(self) -> bytes:
        """Return the 8-byte serialized form (``start`` then ``end``)."""
        return b''.join([u32_le(self.start), u32_le(self.end)])

    __bytes__ = encode

    def __len__(self):
        return 8


class DirectoryTableEntry:
    """Directory-table record: ``start`` (32-bit), ``file_id`` and ``parent_id`` (16-bit)."""

    @overload
    def __init__(self, start: int, file_id: int, parent_id: int, /): ...
    @overload
    def __init__(self, data: bytes | None = None, /): ...
    def __init__(self, *args):
        """Build from ``(start, file_id, parent_id)``, parse ``data``, or zero-fill.

        Args:
            data: Encoded entry; ``start`` is read from bytes 0-3, ``file_id``
                from bytes 4-5 and ``parent_id`` from bytes 6-7. ``None`` or no
                argument yields a zeroed entry.
        """
        if len(args) == 3:
            self.start, self.file_id, self.parent_id = args
        elif not args or args[0] is None:
            self.start = 0
            self.file_id = 0
            self.parent_id = 0
        else:
            data = args[0]
            self.start = uint_le(data[:4])
            self.file_id = uint_le(data[4:6])
            self.parent_id = uint_le(data[6:8])

    def encode(self) -> bytes:
        """Return the 8-byte serialized form."""
        return b''.join([
            u32_le(self.start),
            u16_le(self.file_id),
            u16_le(self.parent_id),
        ])

    __bytes__ = encode

    def __len__(self):
        return 8


class EntryNameTableEntry:
    """Base name-table entry: only the one-byte length/type field.

    A length byte of 0 is what ``EntryNameTableEndOfDirectoryEntry`` emits.
    """

    @overload
    def __init__(self, length: int = 0, /): ...
    @overload
    def __init__(self, data: bytes | None = None, /): ...
    def __init__(self, arg=0):
        """Set the length field from an int or the first byte of a buffer.

        Anything else (including ``None``) gives a length of 0.
        """
        if isinstance(arg, int):
            self.entry_name_length = arg
        elif isinstance(arg, (bytes, bytearray, memoryview)):
            # Accept any bytes-like buffer, not just ``bytes``; indexing yields an int.
            self.entry_name_length = arg[0]
        else:
            self.entry_name_length = 0

    def encode(self) -> bytes:
        """Return the single length byte."""
        return u8_le(self.entry_name_length)

    __bytes__ = encode


class EntryNameTableEndOfDirectoryEntry(EntryNameTableEntry):
    """Terminator entry: a single zero length byte. Constructor arguments are ignored."""

    def __init__(self, *_):
        super().__init__()

    def __len__(self):
        return 1


class EntryNameTableFileEntry(EntryNameTableEntry):
    """File entry: a length byte followed by the Shift-JIS encoded name."""

    @overload
    def __init__(self, name: str = '', /): ...
    @overload
    def __init__(self, data: bytes | None = None, /): ...
    def __init__(self, arg=''):
        """Build from a name, parse raw ``data``, or create an empty entry (``None``)."""
        if arg is None:
            super().__init__()
            # entry_name must always exist; encode() reads it.
            self.entry_name = ''
        elif isinstance(arg, str):
            # The length field counts encoded bytes, not characters.
            super().__init__(len(arg.encode(_NAME_ENCODING)))
            self.entry_name = arg
        else:
            super().__init__(arg)
            raw_name = arg[1:self.entry_name_length + 1]
            self.entry_name = bytes(raw_name).decode(_NAME_ENCODING)

    def encode(self) -> bytes:
        """Serialize as length byte + encoded name, refreshing ``entry_name_length``."""
        name = self.entry_name.encode(_NAME_ENCODING)
        # Re-derive the length in case entry_name was reassigned after construction.
        # NOTE(review): a name longer than 0x7F bytes would set the directory
        # flag bit in the length byte; this is not checked here.
        self.entry_name_length = len(name)
        return b''.join([super().encode(), name])

    def __len__(self):
        return self.entry_name_length + 1


class EntryNameTableDirectoryEntry(EntryNameTableEntry):
    """Sub-directory entry: flagged length byte, Shift-JIS name, 16-bit directory id."""

    @overload
    def __init__(self, name: str = '', directory_id: int = 0, /): ...
    @overload
    def __init__(self, data: bytes | None = None, /): ...
    def __init__(self, *args):
        """Build from ``(name[, directory_id])``, parse raw ``data``, or create an empty entry."""
        if not args or args[0] is None:
            # Keep the directory flag set with a zero-length name so the
            # length byte is not the end-of-directory value and __len__ (3)
            # matches what encode() produces.
            super().__init__(_DIRECTORY_FLAG)
            self.entry_name = ''
            self.directory_id = 0
        elif isinstance(args[0], str):
            name = args[0]
            # The length field counts encoded bytes, not characters.
            super().__init__(len(name.encode(_NAME_ENCODING)) | _DIRECTORY_FLAG)
            self.entry_name = name
            # ``directory_id`` is optional per the overload.
            self.directory_id = args[1] if len(args) > 1 else 0
        else:
            data = args[0]
            super().__init__(data)
            name_len = self.entry_name_length & _NAME_LENGTH_MASK
            self.entry_name = bytes(data[1:name_len + 1]).decode(_NAME_ENCODING)
            self.directory_id = uint_le(data[name_len + 1:name_len + 3])

    def encode(self) -> bytes:
        """Serialize as flagged length byte + encoded name + ``u16_le`` directory id.

        Refreshes ``entry_name_length`` first so a reassigned name is encoded
        with the correct length.
        """
        name = self.entry_name.encode(_NAME_ENCODING)
        # NOTE(review): names longer than 0x7F bytes do not fit the 7-bit length; not checked here.
        self.entry_name_length = len(name) | _DIRECTORY_FLAG
        return b''.join([super().encode(), name, u16_le(self.directory_id)])

    def __len__(self):
        return (self.entry_name_length & _NAME_LENGTH_MASK) + 3