# NOTE(review): this module appears to be damaged and is not valid Python as
# it stands. Two kinds of corruption are visible in the text below:
#   1. Newlines/indentation were collapsed, so many statements share one
#      physical line (the module docstring, imports and class header all sit
#      on the first line). As a side effect, every inline `# ...` comment now
#      swallows the remainder of its physical line.
#   2. Text between a '<' and the next '>' looks to have been stripped, as if
#      by an HTML-tag filter. Examples: `struct.unpack_from(" 0:` (a format
#      string such as "<I" and the code up to a later '>' are gone),
#      `pdata.find(b"") if xml_start < 0` (the XML start/end markers and the
#      xml_end assignment are gone), and near the end the text before
#      `= 0 else -1` in the Sandbox/Seatbelt lookup is missing. An unknown
#      amount of code, possibly including whole methods, was lost in those
#      spans.
# The lost spans cannot be reconstructed from what remains. Restore this
# module from version control rather than hand-editing it.
#
# What the surviving code shows: KernelPatcherBase wraps a 64-bit Mach-O
# image (magic 0xFEEDFACF) passed in as `data`. Its constructor:
#   - parses segments and finds kext __text ranges via __PRELINK_INFO;
#   - indexes ADRP instructions by target page and BL instructions by target
#     offset;
#   - picks _panic among the most-called BL targets.
# The class also provides:
#   - string search and ADRP+ADD xref search;
#   - function-start search (PACIBSP / STP x29,x30);
#   - cached disassembly via `_cs`;
#   - emit(), which records a patch, raises RuntimeError on conflicting bytes
#     at an already-patched offset, and writes the bytes into self.data.
#
# NOTE(review) items to check once the file is restored:
#   - _find_panic's fallback assigns candidates[2] (the third-most-called BL
#     target) when more than two candidates exist. It raises IndexError if
#     bl_callers is empty. Confirm the index is intended.
#   - _find_panic's ADD test (mask 0xFFC003E0 == 0x91000000) pins Rn to x0
#     but not Rd, although its comment says "ADD x0, x0, #imm".
"""Base class with all infrastructure for kernel patchers.""" import struct, plistlib, threading from collections import defaultdict from capstone.arm64_const import ( ARM64_OP_REG, ARM64_OP_IMM, ARM64_REG_W0, ARM64_REG_X0, ARM64_REG_X8, ) from .kernel_asm import ( _cs, _rd32, _rd64, _PACIBSP_U32, _FUNC_BOUNDARY_U32S, ) class KernelPatcherBase: def __init__(self, data, verbose=False): self.data = data # bytearray (mutable) self.raw = bytes(data) # immutable snapshot for searching self.size = len(data) self.patches = [] # collected (offset, bytes, description) self._patch_by_off = {} # offset -> (patch_bytes, desc) self.verbose = verbose self._patch_num = 0 # running counter for clean one-liners self._emit_lock = threading.Lock() # Hot-path caches (search/disassembly is repeated heavily in JB mode). self._disas_cache = {} self._disas_cache_limit = 200_000 self._string_refs_cache = {} self._func_start_cache = {} self._log("[*] Parsing Mach-O segments …") self._parse_macho() self._log("[*] Discovering kext code ranges from __PRELINK_INFO …") self._discover_kext_ranges() self._log("[*] Building ADRP index …") self._build_adrp_index() self._log("[*] Building BL index …") self._build_bl_index() self._find_panic() self._log( f"[*] _panic at foff 0x{self.panic_off:X} " f"({len(self.bl_callers[self.panic_off])} callers)" ) # ── Logging ────────────────────────────────────────────────── def _log(self, msg): if self.verbose: print(msg) def _reset_patch_state(self): """Reset patch bookkeeping before a fresh find/apply pass.""" self.patches = [] self._patch_by_off = {} self._patch_num = 0 # ── Mach-O / segment parsing ───────────────────────────────── def _parse_macho(self): """Parse top-level Mach-O: discover BASE_VA, segments, code ranges.""" magic = _rd32(self.raw, 0) if magic != 0xFEEDFACF: raise ValueError(f"Not a 64-bit Mach-O (magic 0x{magic:08X})") self.code_ranges = [] # [(start_foff, end_foff), ...] 
self.all_segments = [] # [(name, vmaddr, fileoff, filesize, initprot)] self.base_va = None ncmds = struct.unpack_from(" 0: self.code_ranges.append((fileoff, fileoff + filesize)) off += cmdsize if self.base_va is None: raise ValueError("__TEXT segment not found — cannot determine BASE_VA") self.code_ranges.sort() total_mb = sum(e - s for s, e in self.code_ranges) / (1024 * 1024) self._log(f" BASE_VA = 0x{self.base_va:016X}") self._log( f" {len(self.code_ranges)} executable ranges, total {total_mb:.1f} MB" ) def _va(self, foff): return self.base_va + foff def _foff(self, va): return va - self.base_va # ── Kext range discovery ───────────────────────────────────── def _discover_kext_ranges(self): """Parse __PRELINK_INFO + embedded kext Mach-Os to find code section ranges.""" self.kext_ranges = {} # bundle_id -> (text_start, text_end) # Find __PRELINK_INFO segment prelink_info = None for name, vmaddr, fileoff, filesize, _ in self.all_segments: if name == "__PRELINK_INFO": prelink_info = (fileoff, filesize) break if prelink_info is None: self._log(" [-] __PRELINK_INFO not found, using __TEXT_EXEC for all") self._set_fallback_ranges() return foff, fsize = prelink_info pdata = self.raw[foff : foff + fsize] # Parse the XML plist xml_start = pdata.find(b"") if xml_start < 0 or xml_end < 0: self._log(" [-] __PRELINK_INFO plist not found") self._set_fallback_ranges() return xml = pdata[xml_start : xml_end + len(b"")] pl = plistlib.loads(xml) items = pl.get("_PrelinkInfoDictionary", []) # Kexts we need ranges for WANTED = { "com.apple.filesystems.apfs": "apfs", "com.apple.security.sandbox": "sandbox", "com.apple.driver.AppleMobileFileIntegrity": "amfi", } for item in items: bid = item.get("CFBundleIdentifier", "") tag = WANTED.get(bid) if tag is None: continue exec_addr = item.get("_PrelinkExecutableLoadAddr", 0) & 0xFFFFFFFFFFFFFFFF kext_foff = exec_addr - self.base_va if kext_foff < 0 or kext_foff >= self.size: continue # Parse this kext's embedded Mach-O to find 
__TEXT_EXEC.__text text_range = self._parse_kext_text_exec(kext_foff) if text_range: self.kext_ranges[tag] = text_range self._log( f" {tag:10s} __text: 0x{text_range[0]:08X} - 0x{text_range[1]:08X} " f"({(text_range[1] - text_range[0]) // 1024} KB)" ) # Derive the ranges used by patch methods self._set_ranges_from_kexts() def _parse_kext_text_exec(self, kext_foff): """Parse an embedded kext Mach-O header and return (__text start, end) in file offsets.""" if kext_foff + 32 > self.size: return None magic = _rd32(self.raw, kext_foff) if magic != 0xFEEDFACF: return None ncmds = struct.unpack_from(" self.size: break cmd, cmdsize = struct.unpack_from(" self.size: break sectname = ( self.raw[sect_off : sect_off + 16] .split(b"\x00")[0] .decode() ) if sectname == "__text": sect_addr = struct.unpack_from( "> 5) & 0x7FFFF immlo = (insn >> 29) & 0x3 imm = (immhi << 2) | immlo if imm & (1 << 20): imm -= 1 << 21 pc = self._va(off) page = (pc & ~0xFFF) + (imm << 12) self.adrp_by_page[page].append((off, rd)) n = sum(len(v) for v in self.adrp_by_page.values()) self._log(f" {n} ADRP entries, {len(self.adrp_by_page)} distinct pages") def _build_bl_index(self): """Index BL instructions by target offset.""" self.bl_callers = defaultdict(list) # target_off -> [caller_off, ...] 
for rng_start, rng_end in self.code_ranges: for off in range(rng_start, rng_end, 4): insn = _rd32(self.raw, off) if (insn & 0xFC000000) != 0x94000000: continue imm26 = insn & 0x3FFFFFF if imm26 & (1 << 25): imm26 -= 1 << 26 target = off + imm26 * 4 self.bl_callers[target].append(off) def _find_panic(self): """Find _panic: most-called function whose callers reference '@%s:%d' strings.""" candidates = sorted(self.bl_callers.items(), key=lambda x: -len(x[1]))[:15] for target_off, callers in candidates: if len(callers) < 2000: break confirmed = 0 for caller_off in callers[:30]: for back in range(caller_off - 4, max(caller_off - 32, 0), -4): insn = _rd32(self.raw, back) # ADD x0, x0, #imm if (insn & 0xFFC003E0) == 0x91000000: add_imm = (insn >> 10) & 0xFFF if back >= 4: prev = _rd32(self.raw, back - 4) if (prev & 0x9F00001F) == 0x90000000: # ADRP x0 immhi = (prev >> 5) & 0x7FFFF immlo = (prev >> 29) & 0x3 imm = (immhi << 2) | immlo if imm & (1 << 20): imm -= 1 << 21 pc = self._va(back - 4) page = (pc & ~0xFFF) + (imm << 12) str_foff = self._foff(page + add_imm) if 0 <= str_foff < self.size - 10: snippet = self.raw[str_foff : str_foff + 60] if b"@%s:%d" in snippet or b"%s:%d" in snippet: confirmed += 1 break break if confirmed >= 3: self.panic_off = target_off return self.panic_off = candidates[2][0] if len(candidates) > 2 else candidates[0][0] # ── Helpers ────────────────────────────────────────────────── def _disas_at(self, off, count=1): """Disassemble *count* instructions at file offset. 
Returns a list.""" if off < 0 or off >= self.size: return [] key = None if count <= 4: key = (off, count) cached = self._disas_cache.get(key) if cached is not None: return cached end = min(off + count * 4, self.size) code = bytes(self.raw[off:end]) insns = list(_cs.disasm(code, off, count)) if key is not None: if len(self._disas_cache) >= self._disas_cache_limit: self._disas_cache.clear() self._disas_cache[key] = insns return insns def _is_bl(self, off): """Return BL target file offset, or -1 if not a BL.""" insns = self._disas_at(off) if insns and insns[0].mnemonic == "bl": return insns[0].operands[0].imm return -1 def _is_cond_branch_w0(self, off): """Return True if instruction is a conditional branch on w0 (cbz/cbnz/tbz/tbnz).""" insns = self._disas_at(off) if not insns: return False i = insns[0] if i.mnemonic in ("cbz", "cbnz", "tbz", "tbnz"): return ( i.operands[0].type == ARM64_OP_REG and i.operands[0].reg == ARM64_REG_W0 ) return False def find_string(self, s, start=0): """Find string, return file offset of the enclosing C string start.""" if isinstance(s, str): s = s.encode() off = self.raw.find(s, start) if off < 0: return -1 # Walk backward to the preceding NUL — that's the C string start cstr = off while cstr > 0 and self.raw[cstr - 1] != 0: cstr -= 1 return cstr def find_string_refs(self, str_off, code_start=None, code_end=None): """Find all (adrp_off, add_off, dest_reg) referencing str_off via ADRP+ADD.""" key = (str_off, code_start, code_end) cached = self._string_refs_cache.get(key) if cached is not None: return cached target_va = self._va(str_off) target_page = target_va & ~0xFFF page_off = target_va & 0xFFF refs = [] for adrp_off, rd in self.adrp_by_page.get(target_page, []): if code_start is not None and adrp_off < code_start: continue if code_end is not None and adrp_off >= code_end: continue if adrp_off + 4 >= self.size: continue nxt = _rd32(self.raw, adrp_off + 4) # ADD (imm) 64-bit: 1001_0001_00_imm12_Rn_Rd if (nxt & 0xFFC00000) != 0x91000000: 
continue add_rn = (nxt >> 5) & 0x1F add_imm = (nxt >> 10) & 0xFFF if add_rn == rd and add_imm == page_off: add_rd = nxt & 0x1F refs.append((adrp_off, adrp_off + 4, add_rd)) self._string_refs_cache[key] = refs return refs def find_function_start(self, off, max_back=0x4000): """Walk backwards to find PACIBSP or STP x29,x30,[sp,#imm]. When STP x29,x30 is found, continues backward up to 0x20 more bytes to look for PACIBSP (ARM64e functions may have several STP instructions in the prologue before STP x29,x30). """ use_cache = max_back == 0x4000 if use_cache: cached = self._func_start_cache.get(off) if cached is not None: return cached result = -1 for o in range(off - 4, max(off - max_back, 0), -4): insn = _rd32(self.raw, o) if insn == _PACIBSP_U32: result = o break dis = self._disas_at(o) if dis and dis[0].mnemonic == "stp" and "x29, x30, [sp" in dis[0].op_str: # Check further back for PACIBSP (prologue may have # multiple STP instructions before x29,x30) for k in range(o - 4, max(o - 0x24, 0), -4): if _rd32(self.raw, k) == _PACIBSP_U32: result = k break if result < 0: result = o break if use_cache: self._func_start_cache[off] = result return result def _disas_n(self, buf, off, count): """Disassemble *count* instructions from *buf* at file offset *off*.""" end = min(off + count * 4, len(buf)) if off < 0 or off >= len(buf): return [] code = bytes(buf[off:end]) return list(_cs.disasm(code, off, count)) def _fmt_insn(self, insn, marker=""): """Format one capstone instruction for display.""" raw = insn.bytes hex_str = " ".join(f"{b:02x}" for b in raw) s = f" 0x{insn.address:08X}: {hex_str:12s} {insn.mnemonic:8s} {insn.op_str}" if marker: s += f" {marker}" return s def _print_patch_context(self, off, patch_bytes, desc): """Print disassembly before/after a patch site for debugging.""" ctx = 3 # instructions of context before and after # -- BEFORE (original bytes) -- lines = [f" ┌─ PATCH 0x{off:08X}: {desc}"] lines.append(" │ BEFORE:") start = max(off - ctx * 4, 0) 
before_insns = self._disas_n(self.raw, start, ctx + 1 + ctx) for insn in before_insns: if insn.address == off: lines.append(self._fmt_insn(insn, " ◄━━ PATCHED")) elif off < insn.address < off + len(patch_bytes): lines.append(self._fmt_insn(insn, " ◄━━ PATCHED")) else: lines.append(self._fmt_insn(insn)) # -- AFTER (new bytes) -- lines.append(" │ AFTER:") after_insns = self._disas_n(self.raw, start, ctx) for insn in after_insns: lines.append(self._fmt_insn(insn)) # Decode the patch bytes themselves patch_insns = list(_cs.disasm(patch_bytes, off, len(patch_bytes) // 4)) for insn in patch_insns: lines.append(self._fmt_insn(insn, " ◄━━ NEW")) # Trailing context after the patch trail_start = off + len(patch_bytes) trail_insns = self._disas_n(self.raw, trail_start, ctx) for insn in trail_insns: lines.append(self._fmt_insn(insn)) lines.append(f" └─") self._log("\n".join(lines)) def emit(self, off, patch_bytes, desc): """Record a patch and apply it to self.data immediately. Writing through to self.data ensures _find_code_cave() sees previously allocated shellcode and won't reuse the same cave. 
""" patch_bytes = bytes(patch_bytes) with self._emit_lock: existing = self._patch_by_off.get(off) if existing is not None: existing_bytes, existing_desc = existing if existing_bytes != patch_bytes: raise RuntimeError( f"Conflicting patch at 0x{off:08X}: " f"{existing_desc!r} vs {desc!r}" ) return self._patch_by_off[off] = (patch_bytes, desc) self.patches.append((off, patch_bytes, desc)) self.data[off : off + len(patch_bytes)] = patch_bytes self._patch_num += 1 patch_num = self._patch_num print(f" [{patch_num:2d}] 0x{off:08X} {desc}") if self.verbose: self._print_patch_context(off, patch_bytes, desc) def _find_by_string_in_range(self, string, code_range, label): """Find string, find ADRP+ADD ref in code_range, return ref list.""" str_off = self.find_string(string) if str_off < 0: self._log(f" [-] string not found: {string!r}") return [] refs = self.find_string_refs(str_off, code_range[0], code_range[1]) if not refs: self._log(f" [-] no code refs to {label} (str at 0x{str_off:X})") return refs # ── Chained fixup pointer decoding ─────────────────────────── def _decode_chained_ptr(self, val): """Decode an arm64e chained fixup pointer to a file offset. 
- auth rebase (bit63=1): foff = bits[31:0] - non-auth rebase (bit63=0): VA = (bits[50:43] << 56) | bits[42:0] """ if val == 0: return -1 if val & (1 << 63): # auth rebase return val & 0xFFFFFFFF else: # non-auth rebase target = val & 0x7FFFFFFFFFF # bits[42:0] high8 = (val >> 43) & 0xFF full_va = (high8 << 56) | target if full_va > self.base_va: return full_va - self.base_va return -1 # ═══════════════════════════════════════════════════════════════ # Per-patch finders # ═══════════════════════════════════════════════════════════════ _COND_BRANCH_MNEMONICS = frozenset( ( "b.eq", "b.ne", "b.cs", "b.hs", "b.cc", "b.lo", "b.mi", "b.pl", "b.vs", "b.vc", "b.hi", "b.ls", "b.ge", "b.lt", "b.gt", "b.le", "b.al", "cbz", "cbnz", "tbz", "tbnz", ) ) def _decode_branch_target(self, off): """Decode conditional branch at off via capstone. Returns (target, mnemonic) or (None, None).""" insns = self._disas_at(off) if not insns: return None, None i = insns[0] if i.mnemonic in self._COND_BRANCH_MNEMONICS: # Target is always the last IMM operand for op in reversed(i.operands): if op.type == ARM64_OP_IMM: return op.imm, i.mnemonic return None, None def _get_kernel_text_range(self): """Return (start, end) file offsets of the kernel's own __TEXT_EXEC.__text. Parses fileset entries (LC_FILESET_ENTRY) to find the kernel component, then reads its Mach-O header to get the __TEXT_EXEC.__text section. Falls back to the full __TEXT_EXEC segment. 
""" # Try fileset entries ncmds = struct.unpack_from("= 0 else -1 if seatbelt_off < 0 or sandbox_off < 0: self._log(" [-] Sandbox/Seatbelt strings not found") return None self._log( f" [*] Sandbox string at foff 0x{sandbox_off:X}, " f"Seatbelt at 0x{seatbelt_off:X}" ) data_ranges = [] for name, vmaddr, fileoff, filesize, prot in self.all_segments: if name in ("__DATA_CONST", "__DATA") and filesize > 0: data_ranges.append((fileoff, fileoff + filesize)) for d_start, d_end in data_ranges: for i in range(d_start, d_end - 40, 8): val = _rd64(self.raw, i) if val == 0 or (val & (1 << 63)): continue if (val & 0x7FFFFFFFFFF) != sandbox_off: continue val2 = _rd64(self.raw, i + 8) if (val2 & (1 << 63)) or (val2 & 0x7FFFFFFFFFF) != seatbelt_off: continue val_ops = _rd64(self.raw, i + 32) if not (val_ops & (1 << 63)): ops_off = val_ops & 0x7FFFFFFFFFF self._log( f" [+] mac_policy_conf at foff 0x{i:X}, " f"mpc_ops -> 0x{ops_off:X}" ) return ops_off self._log(" [-] mac_policy_conf not found") return None def _read_ops_entry(self, table_off, index): """Read a function pointer from the ops table, handling chained fixups.""" off = table_off + index * 8 if off + 8 > self.size: return -1 val = _rd64(self.raw, off) if val == 0: return 0 return self._decode_chained_ptr(val)