""" A Copy of the official implementation of ShareableList with Memory Leak fixed """ from multiprocessing import shared_memory from functools import partial import struct _encoding = "utf8" class ShareableList: """Pattern for a mutable list-like object shareable via a shared memory block. It differs from the built-in list type in that these lists can not change their overall length (i.e. no append, insert, etc.) Because values are packed into a memoryview as bytes, the struct packing format for any storable value must require no more than 8 characters to describe its format.""" # The shared memory area is organized as follows: # - 8 bytes: number of items (N) as a 64-bit integer # - (2 * N + 1) * 8 bytes: offsets from the start of the data area and # `struct` format string for each element # - K bytes: the data area storing item values (with encoding and size # depending on their respective types) # - N bytes: index into _back_transforms_mapping for each element # (for reconstructing the corresponding Python value) _types_mapping = { int: "q", float: "d", bool: "xxxxxxx?", str: "%ds", bytes: "%ds", None.__class__: "xxxxxx?x", } _alignment = 8 _back_transforms_mapping = { 0: lambda value: value, # int, float, bool 1: lambda value: value.rstrip(b"\x00").decode(_encoding), # str 2: lambda value: value.rstrip(b"\x00"), # bytes 3: lambda _value: None, # None } @staticmethod def _extract_recreation_code(value): """Used in concert with _back_transforms_mapping to convert values into the appropriate Python objects when retrieving them from the list as well as when storing them.""" if not isinstance(value, (str, bytes, None.__class__)): return 0 elif isinstance(value, str): return 1 elif isinstance(value, bytes): return 2 else: return 3 # NoneType def __init__(self, sequence=None, *, name=None): if name is None or sequence is not None: sequence = sequence or () _formats = [ self._types_mapping[type(item)] if not isinstance(item, (str, bytes)) else self._types_mapping[type(item)] % ( self._alignment * ( len( item.encode(_encoding) if isinstance(item, str) else item ) // self._alignment + 1 ), ) for item in sequence ] self._list_len = len(_formats) assert ( sum(len(fmt) <= 8 for fmt in _formats) == self._list_len ) offset = 0 # The offsets of each list element into the shared memory's # data area (0 meaning the start of the data area, not the start # of the shared memory area). _allocated_offsets_and_fmts = [0] for fmt in _formats: offset += ( self._alignment if fmt[-1] != "s" else int(fmt[:-1]) ) _allocated_offsets_and_fmts.append( fmt.encode(_encoding) ) _allocated_offsets_and_fmts.append(offset) _recreation_codes = [ self._extract_recreation_code(item) for item in sequence ] requested_size = struct.calcsize( "q" + self._format_size_metainfo + "".join(_formats) + self._format_back_transform_codes ) self.shm = shared_memory.SharedMemory( name, create=True, size=requested_size ) else: self.shm = shared_memory.SharedMemory(name) if sequence is not None: self._data_size = _allocated_offsets_and_fmts[-1] _enc = _encoding struct.pack_into( "q" + self._format_size_metainfo, self.shm.buf, 0, self._list_len, *(_allocated_offsets_and_fmts), ) struct.pack_into( "".join(_formats), self.shm.buf, self._offset_data_start, *( v.encode(_enc) if isinstance(v, str) else v for v in sequence ), ) struct.pack_into( self._format_back_transform_codes, self.shm.buf, self._offset_back_transform_codes, *(_recreation_codes), ) else: self._list_len = len( self ) # Obtains size from offset 0 in buffer. self._data_size = self._get_allocated_offset(self._list_len) def _get_allocated_offset_and_packing_format(self, position): "Get the allocated offset and packing formats for a single value stored at the given position" if (position >= self._list_len) or ( self._list_len < 0 ): # number of offsets is 1 more than self._list_len raise IndexError("Requested position out of range.") offset, fmt = struct.unpack_from( "q8s", self.shm.buf, (2 * position + 1) * 8 ) fmt = fmt.rstrip(b"\x00") fmt_as_str = fmt.decode(_encoding) return offset, fmt_as_str def _get_allocated_offset(self, position): "Get the allocated offset for a single value stored at the given position" if (position > self._list_len) or ( self._list_len < 0 ): # number of offsets is 1 more than self._list_len raise IndexError("Requested position out of range.") offset = struct.unpack_from( "q", self.shm.buf, (2 * position + 1) * 8 )[0] return offset def _get_back_transform(self, position): "Gets the back transformation function for a single value." if (position >= self._list_len) or (self._list_len < 0): raise IndexError("Requested position out of range.") transform_code = struct.unpack_from( "b", self.shm.buf, self._offset_back_transform_codes + position, )[0] transform_function = self._back_transforms_mapping[ transform_code ] return transform_function def _set_packing_format_and_transform( self, position, fmt_as_str, value ): """Sets the packing format and back transformation code for a single value in the list at the specified position.""" if (position >= self._list_len) or (self._list_len < 0): raise IndexError("Requested position out of range.") struct.pack_into( "8s", self.shm.buf, (2 * (position + 1)) * 8, fmt_as_str.encode(_encoding), ) transform_code = self._extract_recreation_code(value) struct.pack_into( "b", self.shm.buf, self._offset_back_transform_codes + position, transform_code, ) def __getitem__(self, position): position = ( position if position >= 0 else position + self._list_len ) try: ( item_offset, fmt_as_str, ) = self._get_allocated_offset_and_packing_format(position) offset = self._offset_data_start + item_offset (v,) = struct.unpack_from(fmt_as_str, self.shm.buf, offset) except IndexError: raise IndexError("index out of range") back_transform = self._get_back_transform(position) v = back_transform(v) return v def __setitem__(self, position, value): position = ( position if position >= 0 else position + self._list_len ) try: ( item_offset, current_format, ) = self._get_allocated_offset_and_packing_format(position) offset = self._offset_data_start + item_offset except IndexError: raise IndexError("assignment index out of range") if not isinstance(value, (str, bytes)): new_format = self._types_mapping[type(value)] encoded_value = value else: allocated_length = ( self._get_allocated_offset(position + 1) - item_offset ) encoded_value = ( value.encode(_encoding) if isinstance(value, str) else value ) if len(encoded_value) > allocated_length: raise ValueError( "bytes/str item exceeds available storage" ) if current_format[-1] == "s": new_format = current_format else: new_format = self._types_mapping[str] % ( allocated_length, ) self._set_packing_format_and_transform( position, new_format, value ) struct.pack_into( new_format, self.shm.buf, offset, encoded_value ) def __reduce__(self): return partial(self.__class__, name=self.shm.name), () def __len__(self): return struct.unpack_from("q", self.shm.buf, 0)[0] def __repr__(self): return f"{self.__class__.__name__}({list(self)}, name={self.shm.name!r})" @property def format(self): "The struct packing format used by all currently stored items." return "".join( self._get_allocated_offset_and_packing_format(i)[1] for i in range(self._list_len) ) @property def _format_size_metainfo(self): "The struct packing format used for the items' storage offsets and packing formats." return "q8s" * self._list_len + "q" @property def _format_back_transform_codes(self): "The struct packing format used for the items' back transforms." return "b" * self._list_len @property def _offset_data_start(self): # - 8 bytes for the list length # - (N + 1) * 8 bytes for the element offsets return (self._list_len * 2 + 2) * 8 @property def _offset_back_transform_codes(self): return self._offset_data_start + self._data_size def count(self, value): "L.count(value) -> integer -- return number of occurrences of value." return sum(value == entry for entry in self) def index(self, value): """L.index(value) -> integer -- return first index of value. Raises ValueError if the value is not present.""" for position, entry in enumerate(self): if value == entry: return position else: raise ValueError(f"{value!r} not in this container")