# SPDX-FileCopyrightText: 2025 Jeff Epler
#
# SPDX-License-Identifier: GPL-3.0-only

import sys
import pathlib

# The script takes exactly three arguments: the loader stub, the story file
# and the name of the disk image to create.
if len(sys.argv) != 4:
    raise SystemExit(f"Usage: interl5.py stub.bin filename.z5 filename.dsk")

# Both inputs are small enough to be read into memory in one go.
stub = pathlib.Path(sys.argv[1]).read_bytes()
zcode = pathlib.Path(sys.argv[2]).read_bytes()

# The stub is later split into 64 blocks of 256 bytes, so it must be exactly
# 16384 bytes long.
if len(stub) != 16384:
    raise SystemExit(f"stub {sys.argv[1]} has unexpected size (expected 16384, got {len(stub)})")

def make_gcr():
    """Yield, in ascending order, every byte value usable as a GCR disk byte.

    A value qualifies when its top bit is set and its low seven bits,
    written as a 7-character binary string, satisfy all of:

    * no run of three consecutive zeros,
    * at most one pair of consecutive zeros,
    * at least one pair of consecutive ones.
    """
    for low7 in range(128):
        bits = format(low7, "07b")
        no_long_zero_run = '000' not in bits
        at_most_one_zero_pair = bits.count('00') <= 1
        has_adjacent_ones = '11' in bits
        if no_long_zero_run and at_most_one_zero_pair and has_adjacent_ones:
            # The top bit is always set in the emitted byte.
            yield low7 | 128


# Lookup table: 6-bit value -> disk byte.
gcr = list(make_gcr())

# Number of leading values in each encoded sector; each one packs the low two
# bits of three different data bytes.
twobit_count = 86

# Number of 256-byte sectors written to the side A image
# (presumably 35 tracks of 16 sectors each).
oneside = 16 * 35

# Sector interleave table: the position of a block's low nibble in this list
# selects where the block lands within its group of 16 linear sectors.
interleave = [
    0x0, 0x7, 0xe, 0x6, 0xd, 0x5, 0xc, 0x4,
    0xb, 0x3, 0xa, 0x2, 0x9, 0x1, 0x8, 0xf,
]

def togcr(trackno, data):
    """Encode 18 sectors of track data into a 6656-byte GCR track buffer.

    The buffer starts with the bytes D5 AA AD, followed by the track number
    as two bytes (``(trackno >> 1) | 0xaa`` and ``trackno | 0xaa``), followed
    by 18 encoded sectors of ``twobit_count + 256 + 1`` disk bytes each.  The
    rest of the buffer is filled with 0xff.

    Parameters:
        trackno: 0-based track number.
        data: bytes-like track payload.  Sector ``s`` is
            ``data[s*256:(s+1)*256]``; up to two bytes following each sector
            are also folded into the two-bit area (zeros are used when they
            are not available).

    Returns:
        A ``bytearray`` of length 6656.

    Raises:
        IndexError: if a sector slice holds fewer than 256 bytes.

    As a side effect, the first two bytes and the length of each padded
    sector slice are printed.
    """
    def swap_low_bits(byte):
        # Exchange bit 0 and bit 1; all higher bits are dropped.
        return ((byte & 1) << 1) | ((byte & 2) >> 1)

    # Track prologue, then the gcr4,4 encoding of the 0-based track number.
    encoded = bytearray((0xd5, 0xaa, 0xad, (trackno >> 1) | 0xaa, trackno | 0xaa))

    # Now, 18 times: 86 'twobit' + 256 'sixbit' values plus a checksum byte.
    for sector in range(18):
        start = sector * 256
        payload = data[start:start + 258] + b'\0\0'
        print(payload[:2], len(payload))

        # 'twobit' area: each value packs the swapped low two bits of three
        # bytes spaced twobit_count apart.
        # C implementation of interlz5 has inconsistent behavior for
        # the last 2 bits which makes it difficult to exactly reproduce
        values = [
            swap_low_bits(payload[k])
            | (swap_low_bits(payload[k + twobit_count]) << 2)
            | (swap_low_bits(payload[k + 2 * twobit_count]) << 4)
            for k in range(twobit_count)
        ]
        # 'sixbit' area: the upper six bits of each of the 256 data bytes.
        values.extend(payload[k] >> 2 for k in range(256))

        # Each value is XORed with its predecessor before table lookup; the
        # last value is then emitted on its own as the trailing checksum.
        previous = 0
        for current in values:
            encoded.append(gcr[previous ^ current])
            previous = current
        encoded.append(gcr[previous])

    # Pad the remainder of the track with 0xff.
    return encoded.ljust(6656, b"\xff")

# A sector's worth of zero bytes, used both as padding and as the content of
# sectors that were never written.
empty_block = b"\0" * 256

# Linear sector number -> 256-byte sector content, one mapping per disk side.
disk1 = {}
disk2 = {}


def put_linear_sector(blocks, l, data):
    """Store one sector in *blocks* under linear sector number *l*.

    Parameters:
        blocks: mapping that receives the sector (modified in place).
        l: linear sector number used as the key.
        data: bytes-like sector content.  Content shorter than 256 bytes is
            zero-padded to exactly 256 bytes; longer content is stored as is.

    As a side effect, a trace line with the sector number and the first two
    bytes of *data* is printed.
    """
    print("pld", l, data[:2])
    if len(data) < 256:
        # Slice the concatenated bytes, not a list wrapping them: the
        # bracketed form stored a one-element list, which cannot be written
        # to the image file later.
        data = (data + empty_block)[:256]
    blocks[l] = data

# The 64 blocks of the loader stub occupy linear sectors 0..63 of side A.
for i in range(64):
    put_linear_sector(disk1, i, stub[i*256:(i+1)*256])

# Number of story blocks stored on side A; any further blocks go to side B.
side1_limit = 394

# Round up so that a trailing partial block of the story file is not dropped.
for i in range((len(zcode) + 255) // 256):
    # Zero-pad here so every block handed on is exactly 256 bytes long.
    block = zcode[i*256:(i+1)*256].ljust(256, b"\0")
    if i < side1_limit:
        # Side A: keep the 16-block group, but place the block at the
        # position its low nibble has in the interleave table.
        l = 64 + (i & 0xff0) + interleave.index(i & 0xf)
        put_linear_sector(disk1, l, block)
    else:
        # Side B: blocks are stored in plain linear order.
        l = i - side1_limit
        print("sideb", l)
        put_linear_sector(disk2, l, block)

# Side A: a flat image of `oneside` sectors; sectors that were never stored
# are written as zeros.
with open(sys.argv[3], "wb") as f:
    f.writelines(disk1.get(i, empty_block) for i in range(oneside))

# Side B is only produced when the story did not fit on side A.
if disk2:
    print(disk2.keys())
    sidea = pathlib.Path(sys.argv[3])
    # Derive the side B name from the side A name.  "-sidea" can only appear
    # at the END of the stem, so it has to be stripped as a suffix; stripping
    # it as a prefix turned "game-sidea.dsk" into "game-sidea-sideb.nib".
    sideb = sidea.with_stem(sidea.stem.removesuffix("-sidea") + "-sideb").with_suffix(".nib")

    # Each track gets 19 blocks: its own 18 sectors plus the first sector of
    # the following track, because togcr() reads two bytes past the end of
    # every sector.
    tracks = [
        b"".join(disk2.get(track * 18 + j, empty_block) for j in range(19))
        for track in range(35)
    ]

    # GCR-encoded image of side B.
    with open(sideb, "wb") as f:
        for track, trkdata in enumerate(tracks):
            f.write(togcr(track, trkdata))

    # Unencoded copy of exactly what was handed to togcr().
    # NOTE(review): this keeps the 19th (look-ahead) block of every track, so
    # the first sector of each following track appears twice in the file --
    # confirm that this is the intended layout of the .raw output.
    with open(sideb.with_suffix(".raw"), "wb") as f:
        f.writelines(tracks)

