import zstd
from multiprocessing import Pool
#from multiprocessing.pool import ThreadPool
#from concurrent.futures import ThreadPoolExecutor
import sys
import string
# from bitstring import ConstBitStream
import re
import time
import os

# Command line: <data file> <missing file ids list> [process count]
file_path = sys.argv[1]
missing_fileids_path = sys.argv[2]
process_count = 16
try:
    process_count = int(sys.argv[3])
except (IndexError, ValueError):
    # Optional third argument absent or not an integer: keep the default.
    pass

# Not referenced by the code in this file (scan_range reads into its own
# local buffer); NOTE(review): confirm nothing else uses it before removing.
data = []

# Not referenced by the code in this file.
file_start_offset = 0
# Number of bytes consecutive read windows in scan_range overlap by, so that
# matches lying near a window boundary are not missed.
window_half_size = 756
# Not referenced by the code in this file.
record_max_size = 1024

# Characters that are candidates for being emitted verbatim by hex_escape().
printable = string.ascii_letters + string.digits + string.punctuation + ' '

# Not referenced by the code in this file.
zero_neighbors = string.ascii_letters + '_'

# printable minus the single quote and the backslash: hex_escape() output is
# written inside a single-quoted b'...' literal, where a quote would end the
# literal and a backslash would start an escape sequence.
_LITERAL_SAFE = frozenset(printable) - {"'", "\\"}


def hex_escape(s):
    r"""Render bytes as the body of a single-quoted Python bytes literal.

    Args:
        s: bytes-like object; iterating it yields ints in 0..255.

    Returns:
        A str in which ASCII letters, digits, space and punctuation (except
        ``'`` and ``\``) appear as themselves, and every other byte appears
        as ``\xNN`` with two lowercase hex digits.  ``b'<result>'`` therefore
        evaluates back to exactly ``s``.
    """
    return ''.join(chr(c) if chr(c) in _LITERAL_SAFE else r'\x{0:02x}'.format(c) for c in s)


def _find_records(raw_data, range_end, needle, pre_offset, post_offset,
                  buf_size, overlap):
    """Find every occurrence of ``needle`` in an open binary file.

    The file is read window by window; each window's search area is
    ``buf_size - overlap`` bytes long and the areas tile the file without
    gaps, so every occurrence is reported exactly once (by the window whose
    search area holds its first byte).  This is only correct while
    ``len(needle) <= overlap``, ``post_offset <= overlap`` and
    ``overlap < buf_size``.

    Args:
        raw_data: file object opened in binary mode.
        range_end: size of the file in bytes.
        needle: bytes to search for, matched literally.
        pre_offset: bytes of context to keep before each match.
        post_offset: bytes to keep from the match start onwards.
        buf_size: bytes read per window (besides the leading context).
        overlap: bytes shared by consecutive windows.

    Returns:
        List of ``(start, end, escaped)`` tuples: the absolute byte range
        ``[start, end)`` of each record and its hex_escape()d contents.
    """
    pattern = re.compile(re.escape(needle))
    step = buf_size - overlap
    hits = []
    offset = 0  # absolute start of the current window's search area
    while offset < range_end:
        # Also read up to pre_offset bytes before the search area so matches
        # right at its start still get their full leading context.
        read_start = max(0, offset - pre_offset)
        lead = offset - read_start
        raw_data.seek(read_start)
        chunk = raw_data.read(lead + buf_size)
        is_last = read_start + len(chunk) >= range_end
        # Matches starting at or past `limit` lie in the overlap and are
        # reported by the next window instead, with full trailing context.
        limit = len(chunk) if is_last else lead + step
        for m in pattern.finditer(chunk, lead):
            pos = m.start()
            if pos >= limit:
                break
            rec_start = max(0, pos - pre_offset)
            record = chunk[rec_start:pos + post_offset]
            abs_start = read_start + rec_start
            hits.append((abs_start, abs_start + len(record), hex_escape(record)))
        if is_last:
            break
        offset += step
    return hits


def _write_records(out_path, hits):
    """Append each ``(start, end, escaped)`` hit to ``out_path``.

    Every hit becomes a ``# Found at [start-end]`` comment line followed by
    a ``records.append(b'...')`` Python source line.
    """
    with open(out_path, 'a+') as f:
        for start, end, record in hits:
            f.write(f'# Found at [{start}-{end}]\n')
            f.write(f"records.append(b'{record}')\n")


def scan_range(params):
    """Search the whole data file (``file_path``) for one missing file id.

    Only the first 10 bytes of the id are searched for.  Every occurrence,
    with up to 70 bytes before it and 724 bytes from the match on, is
    appended to a text file named after the id in the current directory.
    Each pool worker holds one read window (~512 MiB) in memory at a time.

    Args:
        params: ``(fileid_number, fileid)``; the number only labels log
            lines, ``fileid`` is the ASCII id as bytes.

    Returns:
        Number of occurrences found; 0 when the id is shorter than 20 bytes,
        nothing matched, or the search/write failed.
    """
    fileid_number = params[0]
    fileid = params[1]
    fileid_str = fileid.decode("ascii")

    pre_offset = 70
    post_offset = 724
    buf_size = 1024*1024*512
    fileid_part = fileid[0:10]

    with open(file_path, 'rb') as raw_data:
        range_end = os.path.getsize(file_path)
        print(f'== [{fileid_number}] Searching for missing fileid: {fileid_str}')

        # Ids shorter than 20 bytes are not searched for.
        if len(fileid) < 20:
            return 0

        try:
            start_time = time.monotonic()
            hits = _find_records(raw_data, range_end, fileid_part, pre_offset,
                                 post_offset, buf_size, window_half_size)
            end_time = time.monotonic()

            print(f'== [{fileid_number}] == Search for pattern {fileid_part} took {int(end_time-start_time)}s - found {len(hits)} offsets')

            if not hits:
                print(f'** NO RECORDS FOR {fileid_str}')
                return 0

            _write_records(fileid_str, hits)
        except Exception as e:
            # Per-id boundary: report and let the pool carry on with other ids.
            print(f'!!! Processing fileid {fileid_str} failed due to {str(e)}')
            return 0

    return len(hits)

if __name__ == '__main__':
    print('@@ Loading missing file ids list...')

    # (line index, ASCII id as bytes) pairs; the index only labels worker log
    # lines, so it keeps counting blank lines even though they are skipped.
    with open(missing_fileids_path, 'r') as f:
        missing_fileids = [(i, line.rstrip().encode('ascii'))
                           for i, line in enumerate(f)
                           if line.strip()]

    # The context manager shuts the workers down once map() has returned,
    # so no worker processes outlive the run.
    with Pool(process_count) as p:
        print('@@ Starting processing pool...')
        found = p.map(scan_range, missing_fileids)

    print('@@ Completed')
    # Workers may return None or 0 for ids that produced nothing.
    total = sum(n for n in found if n)
    print(f'@@ Found {total} occurrences for {len(missing_fileids)} file ids')
