import zstd
from concurrent.futures import ThreadPoolExecutor
import sys
import string
import re
import time
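
# Scan a zstd-compressed disk image for fragments of "missing" file IDs and
# dump the bytes surrounding every match to a per-fileid text file.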

file_path = sys.argv[1]
missing_fileids_path = sys.argv[2]
thread_count = 16
try:
    thread_count = int(sys.argv[3])
except (IndexError, ValueError):
    pass  # optional third argument; keep the default
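
# Example invocation (scan_image.py is a hypothetical name for this script):
#   python scan_image.py disk.img.zst missing_fileids.txt 16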

data = b''

def load_raw_data(path):
    global data
    print(f'@@ Opening compressed disk image {path}...')
    with open(path, 'rb') as f:
        print('@@ Reading compressed disk image into memory...')
        file_data = f.read()
        print('@@ Decompressing...')
        data = zstd.decompress(file_data)
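
# Note: this holds both the compressed and decompressed image in memory at
# once. If that is too much, the `zstandard` package (an assumption - it is
# not the module this script imports) can decompress as a stream, roughly:
#     import zstandard
#     with open(path, 'rb') as f:
#         with zstandard.ZstdDecompressor().stream_reader(f) as reader:
#             data = reader.read()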


printable = string.ascii_letters + string.digits + string.punctuation + ' '


def hex_escape(s):
    """Render a bytes slice as text, escaping non-printable bytes as \\xNN."""
    return ''.join(chr(c) if chr(c) in printable else r'\x{0:02x}'.format(c) for c in s)
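
# Example: hex_escape(b'ab\x00cd') == 'ab\\x00cd' - printable bytes pass
# through unchanged, every other byte becomes a \xNN escape.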


def scan_range(params):
    global data
    fileid_number, fileid = params
    fileid_str = fileid.decode("ascii")
    try:
        print(f'== [{fileid_number}] Searching for missing fileid: {fileid_str}')
        pre_offset = 70
        post_offset = 724

        if len(fileid) < 20:
            return 0

        # Use the middle 16 bytes of the fileid as the search pattern.
        mid = len(fileid) // 2
        fileid_part = fileid[mid - 8:mid + 8]

        start_time = time.time()
        offsets = []

        print(f'== [{fileid_number}] == Searching for pattern {fileid_part} of fileid {fileid_str}')

        # A naive byte-by-byte Python loop over the image was far too slow;
        # re.finditer runs the search in C. re.escape prevents bytes in the
        # pattern from being interpreted as regex metacharacters.
        for m in re.finditer(re.escape(fileid_part), data):
            offsets.append(m.start())

        end_time = time.time()

        print(f'== [{fileid_number}] == Search for pattern {fileid_part} took {int(end_time-start_time)}s - found {len(offsets)} offsets')

        found_count = len(offsets)

        if len(offsets) == 0:
            print(f'** NO RECORDS FOR {fileid_str}')
            return 0

        with open(fileid_str, 'w') as f:
            for offset in offsets:
                # Dump each match with pre_offset bytes of leading context
                # and post_offset bytes of trailing context, clamped so the
                # slice never starts before the beginning of the image.
                start = max(0, offset - pre_offset)
                f.write(f'==[{start}-{offset + post_offset}]===\n')
                f.write(hex_escape(data[start:offset + post_offset]))
                f.write('\n=====\n\n')

        return found_count
    except Exception as e:
        print(f'!!! Processing fileid {fileid_str} failed due to {str(e)}')
        return 0
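
# An equivalent regex-free search (a sketch; bytes.find also runs in C and
# needs no escaping):
#     pos = data.find(fileid_part)
#     while pos != -1:
#         offsets.append(pos)
#         pos = data.find(fileid_part, pos + 1)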


if __name__ == '__main__':
    print('@@ Loading missing file ids list...')

    missing_fileids = []
    with open(missing_fileids_path, 'r') as f:
        for i, line in enumerate(f):
            missing_fileids.append((i, line.rstrip().encode('ascii')))

    load_raw_data(file_path)

    print('@@ Starting thread pool...')

    # Threads share the decompressed image in memory; worker processes would
    # each need their own copy.
    with ThreadPoolExecutor(thread_count) as executor:
        found = executor.map(scan_range, missing_fileids)
        found_sum = sum(found)

    print(f'@@ Completed - found {found_sum} possible records')
