import snappy
import json
import os
from multiprocessing import Pool
from bitstring import ConstBitStream
import string
import sys
import re

# Usage: <script> <image_file> <process_count>
if len(sys.argv) < 3:
    sys.exit(f'usage: {sys.argv[0]} <image_file> <process_count>')

# Raw image that is scanned for recoverable documents.
file_path = sys.argv[1]
file_size = os.path.getsize(file_path)

# Currently unused; retained as a module-level setting.
file_start_offset = 0

process_count = int(sys.argv[2])
if process_count < 1:
    sys.exit('process_count must be at least 1')

# Integer division keeps offsets exact for very large images; max(1, ...) avoids a
# ZeroDivisionError below when the image is smaller than process_count bytes.
chunk_size_per_process = max(1, file_size // process_count)
chunk_count = file_size // chunk_size_per_process + 1

possible_match = []  # not referenced by the visible code; kept for compatibility

# Overlap (bytes) between consecutive read buffers in the chunk processors.
window_half_size = 756
# Number of bytes treated as the largest record when reading around a match.
record_max_size = 1024

# Characters that hex_escape() copies verbatim; every other byte is escaped.
printable = ''.join((string.ascii_letters, string.digits, string.punctuation, ' '))

# Characters allowed on both sides of a NUL byte that document extraction drops.
zero_neighbors = ''.join((string.ascii_letters, '_'))


def hex_escape(s):
    """Render the byte sequence *s* as text, escaping non-printable bytes.

    A byte whose character is in ``printable`` is kept as that character; any
    other byte becomes a ``\\xNN`` escape with two lowercase hex digits.
    """
    escaped = []
    for value in s:
        char = chr(value)
        if char in printable:
            escaped.append(char)
        else:
            escaped.append(r'\x{0:02x}'.format(value))
    return ''.join(escaped)


def search_buf_header(buf, offset):
    """Return the absolute offset of every 01 00 00 00 byte sequence in *buf*.

    Args:
        buf: bytes-like buffer that was read from absolute file offset *offset*.
        offset: absolute file offset of buf[0].

    Returns:
        List of absolute offsets (offset + index) of each occurrence, ascending.
    """
    signature = b'\x01\x00\x00\x00'
    matches = []
    # bytes.find scans in C; the previous per-byte loop also stopped one position
    # early and missed a signature ending on the last byte of buf.
    pos = buf.find(signature)
    while pos != -1:
        matches.append(offset + pos)
        pos = buf.find(signature, pos + 1)
    return matches


def search_buf_reg(buf, offset):
    """Return candidate read offsets for every b'REG' signature in *buf*.

    Each hit at index i yields offset + i - record_max_size // 2, a position half
    a maximum record before the signature (it can be negative near file start).
    """
    signature = b'REG'
    back = record_max_size // 2
    matches = []
    # bytes.find scans in C; the previous per-byte loop also stopped one position
    # early and missed a signature ending on the last byte of buf.
    pos = buf.find(signature)
    while pos != -1:
        matches.append(offset + pos - back)
        pos = buf.find(signature, pos + 1)
    return matches


def search_buf_mode(buf, offset):
    """Return candidate read offsets for every b'"mo' signature in *buf*.

    Each hit at index i yields offset + i - record_max_size // 2 (may be negative).
    """
    signature = b'"mo'
    back = record_max_size // 2
    matches = []
    # bytes.find scans in C; the previous per-byte loop also stopped one position
    # early and missed a signature ending on the last byte of buf.
    pos = buf.find(signature)
    while pos != -1:
        matches.append(offset + pos - back)
        pos = buf.find(signature, pos + 1)
    return matches


def search_buf_dir(buf, offset):
    """Return candidate read offsets for every b'DIR' signature in *buf*.

    Each hit at index i yields offset + i - record_max_size // 2 (may be negative).
    """
    signature = b'DIR'
    back = record_max_size // 2
    matches = []
    # bytes.find scans in C; the previous per-byte loop also stopped one position
    # early and missed a signature ending on the last byte of buf.
    pos = buf.find(signature)
    while pos != -1:
        matches.append(offset + pos - back)
        pos = buf.find(signature, pos + 1)
    return matches


def search_buf_key(buf, offset):
    """Return candidate read offsets for every b'_ke' signature in *buf*.

    Each hit at index i yields offset + i - 24 (may be negative).
    """
    signature = b'_ke'
    matches = []
    # bytes.find scans in C; the previous per-byte loop stopped two positions
    # early and missed a signature ending on the last bytes of buf.
    pos = buf.find(signature)
    while pos != -1:
        matches.append(offset + pos - 24)
        pos = buf.find(signature, pos + 1)
    return matches


def search_buf_scope(buf, offset):
    """Return candidate read offsets for every b'_sc' signature in *buf*.

    Each hit at index i yields offset + i - 64 (may be negative).
    """
    signature = b'_sc'
    matches = []
    # bytes.find scans in C; the previous per-byte loop stopped two positions
    # early and missed a signature ending on the last bytes of buf.
    pos = buf.find(signature)
    while pos != -1:
        matches.append(offset + pos - 64)
        pos = buf.find(signature, pos + 1)
    return matches


def search_buf_seq(buf, offset):
    """Return candidate read offsets for every b'_se' signature in *buf*.

    Each hit at index i yields offset + i - record_max_size // 2 (may be negative).
    """
    signature = b'_se'
    back = record_max_size // 2
    matches = []
    # bytes.find scans in C; the previous per-byte loop stopped two positions
    # early and missed a signature ending on the last bytes of buf.
    pos = buf.find(signature)
    while pos != -1:
        matches.append(offset + pos - back)
        pos = buf.find(signature, pos + 1)
    return matches


def search_buf_file_meta(buf, offset):
    """Return candidate read offsets for every b'"fi' signature in *buf*.

    Each hit at index i yields offset + i - 324 (may be negative).
    """
    signature = b'"fi'
    matches = []
    # bytes.find scans in C; the previous per-byte loop also stopped one position
    # early and missed a signature ending on the last byte of buf.
    pos = buf.find(signature)
    while pos != -1:
        matches.append(offset + pos - 324)
        pos = buf.find(signature, pos + 1)
    return matches

def search_buf_file_uuid(buf, offset):
    """Return the absolute offset of every b't_uu' signature in *buf*.

    Each hit at index i yields offset + i.
    """
    signature = b't_uu'
    matches = []
    # bytes.find scans in C; the previous per-byte loop also stopped one position
    # early and missed a signature ending on the last byte of buf.
    pos = buf.find(signature)
    while pos != -1:
        matches.append(offset + pos)
        pos = buf.find(signature, pos + 1)
    return matches


def search_buf_by_fileid(buf, offset, fileid):
    """Return the absolute byte offset of the first byte-aligned *fileid* in *buf*.

    Args:
        buf: bytes read from absolute file offset *offset*.
        offset: absolute file offset of buf[0].
        fileid: anything ConstBitStream.find() accepts as a search pattern.

    Returns:
        A list with at most one absolute byte offset.
    """
    s = ConstBitStream(bytes=buf)
    matches = []
    found = s.find(fileid, bytealigned=True)
    if found:
        # find() reports a *bit* position (a multiple of 8 with bytealigned=True);
        # convert it to bytes before adding it to the byte offset.
        matches.append(offset + found[0] // 8)
    return matches


def extract_document_around_offset(match_offset, raw_data, fileid = ''):
    """Read the bytes at *match_offset* and carve JSON documents out of them.

    Reads record_max_size + 24 bytes starting at max(0, match_offset) and runs
    extract_document_around_offset_attempt() on up to three variants:
      1. the bytes as read;
      2. every '"<NUL>,' collapsed to '",';
      3. every NUL byte removed whose two neighbours are both in zero_neighbors.
    Variants 2 and 3 are only tried when they differ from the raw bytes.

    Args:
        match_offset: absolute file offset to read from (negatives clamp to 0).
        raw_data: binary file object; its position is moved by this call.
        fileid: unused; kept for call compatibility.

    Returns:
        Sorted list of distinct documents (bytes).
    """
    raw_data.seek(max(0, match_offset))
    match_buf_raw = raw_data.read(record_max_size + 24)

    documents = list(extract_document_around_offset_attempt(match_buf_raw, match_buf_raw))

    quote_fixed = match_buf_raw.replace(b'"\x00,', b'",')
    if quote_fixed != match_buf_raw:
        documents.extend(extract_document_around_offset_attempt(match_buf_raw, quote_fixed))

    # Built as bytes (not bytearray) so slices returned as documents stay hashable.
    last = len(match_buf_raw) - 1
    key_fixed = bytes(
        value for pos, value in enumerate(match_buf_raw)
        if not (value == 0 and 0 < pos < last
                and chr(match_buf_raw[pos - 1]) in zero_neighbors
                and chr(match_buf_raw[pos + 1]) in zero_neighbors)
    )
    if key_fixed != match_buf_raw:
        documents.extend(extract_document_around_offset_attempt(match_buf_raw, key_fixed))

    # bytes() guards against bytearray documents, which set() cannot hash.
    return sorted(set(bytes(doc) for doc in documents))


def extract_document_around_offset_attempt(match_buf_raw, match_buf):
    """Try to carve complete JSON documents out of *match_buf*.

    Candidate starts are 2..11 bytes before every '{"_k' found in the first half
    of the buffer (+256 bytes); candidate ends are just after every '}' beyond the
    first eighth of the buffer. Each span of 256..1024 bytes containing at most
    one "_key" is snappy-decompressed and parsed as JSON; if decompression fails,
    the span is parsed verbatim as JSON from its first '{'.

    Args:
        match_buf_raw: the unmodified bytes read from disk (kept for debugging).
        match_buf: the (possibly NUL-cleaned) bytes to search.

    Returns:
        List of documents (bytes) that parsed as JSON; may contain duplicates.
    """
    buf_len = len(match_buf)
    close_brace = ord('}')
    possible_document_end_offsets = [
        pos + 1 for pos in range(buf_len // 8, buf_len) if match_buf[pos] == close_brace
    ]
    if not possible_document_end_offsets:
        return []

    # Keep the 4-byte '{"_k' probe inside the buffer; short reads near the end
    # of the file previously raised IndexError here.
    scan_end = min(buf_len // 2 + 256, buf_len - 3)
    possible_document_start_offsets = []
    for pos in range(0, scan_end):
        if match_buf[pos:pos + 4] == b'{"_k':
            possible_document_start_offsets.extend(max(0, pos - k) for k in range(2, 12))
    if not possible_document_start_offsets:
        return []

    documents_found = []
    last_found_end_offset = possible_document_start_offsets[0]
    for i in possible_document_start_offsets:
        for end_offset in possible_document_end_offsets:
            if end_offset < last_found_end_offset:
                continue
            if end_offset <= i or end_offset - i > 1024 or end_offset - i < 256:
                continue

            candidate = match_buf[i:end_offset]
            # A span with two "_key"s would merge two documents.
            if candidate.count(b'_key') > 1:
                continue

            try:
                doc = snappy.decompress(candidate)
                try:
                    json.loads(doc)
                    documents_found.append(doc)
                    last_found_end_offset = end_offset
                except (UnicodeDecodeError, json.decoder.JSONDecodeError):
                    pass
            except Exception:
                # Maybe the JSON is not compressed at all - parse it from its first '{'.
                # first_brace is relative to candidate (the old check compared it with
                # the absolute i, which disabled this path for most spans).
                first_brace = candidate.find(ord('{'))
                if 0 <= first_brace < len(candidate) - 2:
                    raw_json = bytes(candidate[first_brace:])
                    try:
                        json.loads(raw_json)
                    except (UnicodeDecodeError, json.decoder.JSONDecodeError):
                        pass
                    except Exception as e:
                        print(f"************* Skipped invalid JSON document: {hex_escape(raw_json)} {e}")
                    else:
                        documents_found.append(raw_json)
                        last_found_end_offset = end_offset
                        print(f'************** FOUND UNCOMPRESSED JSON RECORD AT OFFSET {i+first_brace}')

    return documents_found



def extract_document_around_offset_old(match_offset, raw_data, fileid = ''):
    """Legacy single-pass variant of extract_document_around_offset().

    Reads record_max_size + 24 bytes at max(0, match_offset), collapses
    '"<NUL>,' to '",', and tries the same candidate (start, end) spans as
    extract_document_around_offset_attempt(). When snappy decompression of a
    span fails, it retries once with NUL bytes removed that sit between two
    characters of '_keyscopefilemetamutators'.

    Args:
        match_offset: absolute file offset to read from (negatives clamp to 0).
        raw_data: binary file object; its position is moved by this call.
        fileid: unused; kept for call compatibility.

    Returns:
        List of decompressed documents (bytes) that parsed as JSON.
    """
    raw_data.seek(max(0, match_offset))
    match_buf_raw = raw_data.read(record_max_size+24)
    match_buf = match_buf_raw.replace(b'"\x00,', b'",')
    buf_len = len(match_buf)

    possible_document_end_offsets = [
        pos + 1 for pos in range(buf_len // 8, buf_len) if match_buf[pos] == ord('}')
    ]
    if not possible_document_end_offsets:
        return []

    # Keep the 4-byte '{"_k' probe inside the buffer (short reads near the end
    # of the file previously raised IndexError).
    scan_end = min(buf_len // 2 + 256, buf_len - 3)
    possible_document_start_offsets = []
    for pos in range(0, scan_end):
        if match_buf[pos:pos + 4] == b'{"_k':
            possible_document_start_offsets.extend(max(0, pos - k) for k in range(2, 12))
    if not possible_document_start_offsets:
        return []

    key_chars = '_keyscopefilemetamutators'
    documents_found = []
    last_found_end_offset = possible_document_start_offsets[0]
    for i in possible_document_start_offsets:
        for end_offset in possible_document_end_offsets:
            if end_offset < last_found_end_offset:
                continue
            if end_offset <= i or end_offset - i > 1024 or end_offset - i < 256:
                continue

            match_string = hex_escape(match_buf[i:end_offset])
            if match_string.count("_key") > 1:
                continue

            try:
                doc = snappy.decompress(match_buf[i:end_offset])
                try:
                    json.loads(doc)
                    documents_found.append(doc)
                    last_found_end_offset = end_offset
                except (UnicodeDecodeError, json.decoder.JSONDecodeError):
                    pass
            except Exception:
                # Remove NULs sandwiched between key-name characters and retry.
                match_buf_cleaned = bytearray()
                cleaned = False
                for it in range(i, end_offset):
                    b = match_buf[it]
                    if b == 0 and i + 1 < it < end_offset - 2 \
                            and chr(match_buf[it - 1]) in key_chars \
                            and chr(match_buf[it + 1]) in key_chars:
                        cleaned = True
                    else:
                        match_buf_cleaned.append(b)

                if cleaned:
                    try:
                        print(f'%%%%%%%% TRYING TO DECOMPRESS RECORD AFTER REMOVING \x00 BETWEEN PRINTABLE CHARACTERS IN RANGE {match_string}')
                        # Pass real bytes: the old code passed a list of ints, which is
                        # not a bytes-like object, so this retry could never succeed.
                        doc = snappy.decompress(bytes(match_buf_cleaned))
                        print(f'%%%%%%%% FOUND VALID RECORD AFTER REMOVING \x00 BETWEEN PRINTABLE CHARACTERS IN RANGE {match_string}')
                        try:
                            json.loads(doc)
                            documents_found.append(doc)
                            last_found_end_offset = end_offset
                        except (UnicodeDecodeError, json.decoder.JSONDecodeError):
                            pass
                    except Exception:
                        pass

    return documents_found


def reg_to_csv(doc, csv_dump):
    try:
        d = json.loads(doc)
        return reg_to_csv_json(d, csv_dump)
    except UnicodeDecodeError as e:
        print(f"Skipped invalid JSON document: {doc}")
    except json.decoder.JSONDecodeError as e:
        print(f"Skipped invalid JSON document: {doc}")

    return False

def reg_to_csv_json(d, csv_dump):
    """Append one tab-separated row for the parsed document *d* to *csv_dump*.

    Columns: _scope, _key, type, parent_uuid, name, size (always 0), mode,
    _timestamp, _deleted, _version, _seq. A missing key becomes an empty field.
    Tabs, carriage returns and newlines in the name are replaced by
    ___TABULATOR___, ___CARRIAGE_RETURN___ and ___NEWLINE___ so that each
    document stays on a single row.

    Args:
        d: value returned by json.loads().
        csv_dump: writable text stream.

    Returns:
        True if a row was written; False if *d* is not a dict with a '_scope' key.
    """
    # Only dicts carrying '_scope' are records; other JSON values (strings,
    # lists) could otherwise reach d.get() and raise.
    if not isinstance(d, dict) or '_scope' not in d:
        return False

    scope = d.get('_scope')
    key = d.get('_key', '')
    parent_uuid = d.get('parent_uuid', '')
    # A JSON null or non-string name used to raise TypeError on the "in" test.
    name = d.get('name', '')
    name = '' if name is None else str(name)
    name = (name.replace('\t', '___TABULATOR___')
                .replace('\r', '___CARRIAGE_RETURN___')
                .replace('\n', '___NEWLINE___'))
    size = 0  # sizes are not recovered; the column is kept for the header layout
    file_type = d.get('type', '')
    mode = d.get('mode', '')
    timestamp = d.get('_timestamp', '')
    deleted = d.get('_deleted', '')
    version = d.get('_version', '')
    seq = d.get('_seq', '')

    csv_dump.write(f'{scope}\t{key}\t{file_type}\t{parent_uuid}\t{name}\t{size}\t{mode}\t{timestamp}\t{deleted}\t{version}\t{seq}\n')

    return True

def dir_to_csv(doc, csv_dump):
    try:
        d = json.loads(doc)
        return dir_to_csv_json(d, csv_dump)
    except UnicodeDecodeError as e:
        print(f"Skipped invalid JSON document: {doc}")
    except json.decoder.JSONDecodeError as e:
        print(f"Skipped invalid JSON document: {doc}")

    return False


def dir_to_csv_json(d, csv_dump):
    """Append one tab-separated row for the parsed directory document *d*.

    Columns: _scope, _key, type, parent_uuid, name, size (always 0), mode,
    _timestamp, _deleted, _version, _seq. A missing key becomes an empty field.
    Tabs, carriage returns and newlines in the name are replaced by
    ___TABULATOR___, ___CARRIAGE_RETURN___ and ___NEWLINE___ so that each
    document stays on a single row.

    Args:
        d: value returned by json.loads().
        csv_dump: writable text stream.

    Returns:
        True if a row was written; False (previously None) if *d* is not a dict
        with a '_scope' key, matching reg_to_csv_json().
    """
    if not isinstance(d, dict) or '_scope' not in d:
        return False

    scope = d.get('_scope')
    key = d.get('_key', '')
    parent_uuid = d.get('parent_uuid', '')
    # A JSON null or non-string name used to raise TypeError on the "in" test.
    name = d.get('name', '')
    name = '' if name is None else str(name)
    name = (name.replace('\t', '___TABULATOR___')
                .replace('\r', '___CARRIAGE_RETURN___')
                .replace('\n', '___NEWLINE___'))
    size = 0  # sizes are not recovered; the column is kept for the header layout
    file_type = d.get('type', '')
    mode = d.get('mode', '')
    timestamp = d.get('_timestamp', '')
    deleted = d.get('_deleted', '')
    version = d.get('_version', '')
    seq = d.get('_seq', '')

    csv_dump.write(f'{scope}\t{key}\t{file_type}\t{parent_uuid}\t{name}\t{size}\t{mode}\t{timestamp}\t{deleted}\t{version}\t{seq}\n')

    return True


def parse_snappy_range(buf):
    """Debug helper: decode the first marker-delimited stretch of *buf* and print it.

    Finds the first b'\\x80\\x00' marker and the next one after it, then tries
    fix_buf()-repaired slices buf[start+i : end+4-j] for i in 0..31 and j in 0..7
    until one snappy-decompresses to JSON. Each failure and the success are printed.

    Returns:
        True if a slice decoded, False otherwise (previously always None).
    """
    marker = b'\x80\x00'
    start = buf.find(marker)
    if start < 0:
        return False
    next_marker = buf.find(marker, start + len(marker))
    if next_marker < 0:
        return False
    # Absolute offset just past the closing marker. The old code added a position
    # found inside buf[start+2:] to a constant, which is only right when start == 0.
    end = next_marker + len(marker)

    if 64 < end - start < 1024 and end < len(buf):
        for i in range(0, 32):
            for j in range(0, 8):
                try:
                    fixed = fix_buf(buf[start+i:end+4-j])
                    d = snappy.decompress(fixed)
                    data = json.loads(d)
                    print(f"==[{i},{j}]= {data}\n")
                    # Stop at the first success; the old `break` only left the inner loop.
                    return True
                except Exception as e:
                    print(f"!!! DECOMPRESSING FAILED DUE TO: {e}")
    return False

def fix_buf(buf):
    """Drop a stray NUL byte next to a JSON comma.

    First every '"<NUL>,' becomes '",', then every ',<NUL>"' becomes ',"'.
    The replacements run in that order on the result of the previous one.
    """
    repairs = ((b'"\x00,', b'",'), (b',\x00"', b',"'))
    for broken, repaired in repairs:
        buf = buf.replace(broken, repaired)
    return buf


def try_decompress2(buf):
    """Snappy-decompress and JSON-parse *buf* while guessing its exact boundaries.

    Tries skipping 6, 7, 5, 8, 4, 9, 3, 2, 1, 0 leading bytes, combined with
    end positions len(buf)-2, len(buf), len(buf)-1, len(buf)-3 (in that order).

    Returns:
        The first successfully parsed JSON value, or False if nothing decodes.
    """
    buf_len = len(buf)
    # The trailing shifts (-2, 1, -1, 0, 2, 3, -3) are clamped to buf_len, so the
    # non-negative ones all repeat the same slice; try each distinct end only once,
    # in first-seen order, so the first success is unchanged.
    end_offsets = []
    for delta in (-2, 1, -1, 0, 2, 3, -3):
        end = min(buf_len, buf_len + delta)
        if end not in end_offsets:
            end_offsets.append(end)

    for skip in (6, 7, 5, 8, 4, 9, 3, 2, 1, 0):
        for end in end_offsets:
            try:
                return json.loads(snappy.decompress(buf[skip:end]))
            except Exception:
                # Not a valid snappy/JSON payload at these bounds; narrowed from a
                # bare except so KeyboardInterrupt can stop the workers.
                continue

    return False


def try_decompress(buf):
    """Snappy-decompress and JSON-parse *buf* with small boundary shifts.

    Skips 0..11 leading bytes and tries end positions len(buf)+3 down to
    len(buf)-2 (ends past the buffer are clamped by slicing).

    Returns:
        The first successfully parsed JSON value, or None if nothing decodes.
    """
    buf_len = len(buf)
    # Slicing clamps any end beyond buf_len to buf_len, so the +3/+2/+1 shifts all
    # repeat the +0 attempt; try each distinct end once, keeping the original order.
    end_offsets = []
    for trim in range(0, 6):
        end = min(buf_len, buf_len + 3 - trim)
        if end not in end_offsets:
            end_offsets.append(end)

    for skip in range(0, 12):
        for end in end_offsets:
            try:
                return json.loads(snappy.decompress(buf[skip:end]))
            except Exception:
                continue

    return None


def process_chunk2(file_range):
    """Recover REG/DIR documents from one byte range of the image into a TSV file.

    The range is read in 10 MiB buffers. In each buffer, every stretch between two
    consecutive b'\\x80\\x00' markers that is 64..1024 bytes long is decoded by
    _decode_chunk(). Decoded dicts whose 'type' is 'REG' or 'DIR' are written to
    recovery-<range start>.csv. Only the range starting at offset 0 writes the header.

    Args:
        file_range: (start, end) absolute byte offsets in file_path. Scanning continues
            up to one buffer past end, so neighbouring ranges overlap.

    Returns:
        Number of REG/DIR documents written. On an unexpected error the error is
        printed and the count found so far is returned (never None), so that the
        caller can sum() the results.
    """
    found_counter = 0
    try:
        range_start, range_end = file_range
        offset = range_start
        buf_size = 1024*1024*10
        # Consecutive buffers overlap by at least two maximum-size records, so a
        # stretch straddling a buffer boundary is seen whole in the next buffer
        # (the old 756-byte overlap was smaller than a 1024-byte stretch).
        overlap = max(window_half_size, 2 * record_max_size)
        # Absolute offset from which markers have not been examined yet; skipping
        # what lies before it keeps the overlap from producing duplicate rows.
        scan_from = range_start
        with open(file_path, 'rb') as raw_data, \
                open(f'recovery-{range_start}.csv', 'w', encoding='utf-8', errors='backslashreplace') as csv_dump:
            if range_start == 0:
                csv_dump.write('#scope\tfile_id\ttype\tparent_uuid\tname\tsize\tmode\ttimestamp\tdeleted\tversion\tsequence\n')
            while offset < range_end + buf_size:
                print(f'PROCESSING RANGE {offset}-{range_end+buf_size}')
                raw_data.seek(offset)
                buf = raw_data.read(buf_size)
                if not buf:
                    break  # past the end of the file

                possible_docs, resume_pos = _scan_buffer_for_chunks(buf, max(0, scan_from - offset))
                scan_from = offset + resume_pos

                # json.loads can yield non-dict values; only dicts are records.
                records = [doc for doc in possible_docs if isinstance(doc, dict)]
                reg_docs = [doc for doc in records if doc.get('type') == 'REG']
                dir_docs = [doc for doc in records if doc.get('type') == 'DIR']
                found_counter += len(reg_docs) + len(dir_docs)

                for reg_doc in reg_docs:
                    reg_to_csv_json(reg_doc, csv_dump)
                for dir_doc in dir_docs:
                    dir_to_csv_json(dir_doc, csv_dump)
                csv_dump.flush()

                offset += buf_size - overlap

                span = range_end - range_start
                progress = (offset - range_start) / span * 100 if span else 100.0
                print(f"Found {found_counter} records in total in range {file_range} [{progress}%]")

                if len(buf) < buf_size:
                    break  # short read: that was the last buffer of the file
    except Exception as e:
        print(f"PROCESS FAILED WITH {e}")
    return found_counter


def _scan_buffer_for_chunks(buf, pos):
    """Decode each marker-delimited stretch of *buf*, searching from index *pos*.

    Returns:
        (docs, resume_pos): the decoded documents, and the index in *buf* from which
        the next (overlapping) buffer should continue searching.
    """
    marker = b'\x80\x00'
    docs = []
    while True:
        marker_pos = buf.find(marker, pos)
        if marker_pos < 0:
            # No marker left; keep the last byte in case a marker is split across
            # the end of this buffer.
            return docs, max(pos, len(buf) - len(marker) + 1)
        start = marker_pos + len(marker)
        next_marker = buf.find(marker, start)
        if next_marker < 0:
            # The closing marker lies beyond this buffer; retry from the next one.
            return docs, marker_pos
        end = next_marker + len(marker)
        if 64 < end - start < 1024:
            doc = _decode_chunk(buf[start:end])
            if doc:
                docs.append(doc)
        # Resume at the closing marker so it is also tried as the opening marker
        # of the next stretch; resuming after it (as before) never examined the
        # stretch that follows it. Searching from an index (not re-slicing buf)
        # also keeps the scan linear.
        pos = next_marker


def _decode_chunk(inner_buf):
    """Decode one marker-delimited stretch; return the parsed JSON value or None.

    First tries the fix_buf()-repaired stretch. Then, for each NUL byte except a
    final one, retries with just that byte removed (without fix_buf()).
    """
    doc = try_decompress2(fix_buf(inner_buf))
    if doc:
        return doc
    for zero_pos in range(len(inner_buf) - 1):
        if inner_buf[zero_pos] == 0:
            doc = try_decompress2(inner_buf[:zero_pos] + inner_buf[zero_pos + 1:])
            if doc:
                return doc
    return None


def process_chunk(file_range):
    """Signature-based recovery of one byte range of the image into a TSV file.

    The range is read in 10 MiB buffers. In each buffer the REG, DIR, '_ke' and '_sc'
    signature searchers propose read offsets. extract_document_around_offset()
    carves JSON documents there. Documents containing b'REG' are written with
    reg_to_csv(); the others that contain b'DIR' are written with dir_to_csv().
    Output goes to recovery-<range start>.csv. Only the range starting at offset 0
    writes the header.

    Args:
        file_range: (start, end) absolute byte offsets in file_path. Scanning continues
            up to one buffer past end.

    Returns:
        Number of REG/DIR documents found. On an unexpected error the error is
        printed and the count so far is returned (never None).
    """
    found_counter = 0
    try:
        range_start, range_end = file_range
        offset = range_start
        buf_size = 1024*1024*10
        # Searchers whose hits seed extraction, in the order their results were
        # gathered before. search_buf_file_meta, search_buf_file_uuid and
        # search_buf_mode remain disabled, as they were.
        searchers = (search_buf_reg, search_buf_dir, search_buf_key, search_buf_scope)
        with open(file_path, 'rb') as raw_data, \
                open(f'recovery-{range_start}.csv', 'w', encoding='utf-8', errors='backslashreplace') as csv_dump:
            if range_start == 0:
                csv_dump.write('#scope\tfile_id\ttype\tparent_uuid\tname\tsize\tmode\ttimestamp\tdeleted\tversion\tsequence\n')
            while offset < range_end + buf_size:
                print(f'PROCESSING RANGE {offset}-{range_end+buf_size}')
                # extract_document_around_offset() moves raw_data, so always re-seek.
                raw_data.seek(offset)
                buf = raw_data.read(buf_size)
                if not buf:
                    break  # past the end of the file

                possible_docs = []
                for searcher in searchers:
                    for match_offset in searcher(buf, offset):
                        possible_docs.extend(extract_document_around_offset(match_offset, raw_data) or [])

                # The same document is usually reached from several signatures
                # (e.g. 'REG' and '_key'); keep one copy, in first-seen order.
                unique_docs = list(dict.fromkeys(bytes(doc) for doc in possible_docs))
                reg_docs = [doc for doc in unique_docs if b'REG' in doc]
                dir_docs = [doc for doc in unique_docs if b'REG' not in doc and b'DIR' in doc]
                found_counter += len(reg_docs) + len(dir_docs)

                for reg_doc in reg_docs:
                    reg_to_csv(reg_doc, csv_dump)
                for dir_doc in dir_docs:
                    dir_to_csv(dir_doc, csv_dump)
                csv_dump.flush()

                offset += buf_size - window_half_size

                span = range_end - range_start
                progress = (offset - range_start) / span * 100 if span else 100.0
                print(f"Found {found_counter} records in total in range {file_range} [{progress}%]")

                if len(buf) < buf_size:
                    break  # short read: that was the last buffer of the file
    except Exception as e:
        print(f"PROCESS FAILED WITH {e}")
    return found_counter


if __name__ == '__main__':
    # Adjacent ranges overlap by this many bytes, so a record of up to
    # record_max_size bytes that crosses a range boundary lies wholly in one range.
    # (The old start*(chunk_size-2048) grew the overlap with every range and went
    # negative, i.e. an invalid seek, for images smaller than process_count*2048.)
    chunk_overlap = 2048
    chunks = [
        (max(0, index * chunk_size_per_process - chunk_overlap),
         min(file_size, (index + 1) * chunk_size_per_process))
        for index in range(0, chunk_count - 1)
    ]
    if chunks:
        # Integer division can leave the last few bytes outside every range.
        chunks[-1] = (chunks[-1][0], file_size)
    print(f'Processing file in parallel in chunks: {chunks}')
    with Pool(process_count) as p:
        found = p.map(process_chunk2, chunks)
    print(f'Completed - found {found} in searched ranges')
    # Count a worker that returned None (no count) as zero instead of crashing sum().
    print(f'Total {sum(count or 0 for count in found)} records were found')

