import sys
import struct
import os
import os.path
import numpy as np
import imageio
# The tileset files well0000.png etc. must exist in *png* format in the
# directory below.  It is resolved relative to the directory containing this
# file (symlinks resolved), so it is also correct when the file is imported
# as a module rather than run as the entry-point script.
tileset_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                           '..', 'data', 'extracted_art.vol')
si4 = struct.Struct('<i')  # 4-byte signed integer, little endian
si2 = struct.Struct('<h')  # 2-byte signed integer, little endian
class Map:
    """Parser for a binary ``.map`` file held entirely in memory.

    ``__init__`` decodes the whole buffer immediately (via ``read_all``),
    reading the sections in file order with a byte cursor (``self.index``).
    Format checks are done with ``assert``.

    Attributes set while parsing:
        width, height: map size in tiles (``width`` is stored as log2).
        num_tilesets: number of slots in the tileset slot table.
        map: ``(width, height)`` array of raw little-endian uint32 cells,
            indexed ``map[x, y]``.
        map_tile_index: bits 5..15 of each cell; an index into ``tiles``.
        tileset_filenames, tileset_lengths, num_valid_tilesets: the used
            tileset slots (PNG file name and tile count of each).
        num_tiles, tiles: list of ``(tileset, tile_index)`` pairs.
    """
    def __init__(self, data):
        self.data = data    # complete file contents (bytes)
        self.index = 0      # read cursor into self.data
        self.N = len(data)
        self.read_all()
    @classmethod
    def read_file(cls, filename):
        """Read *filename* in binary mode and parse it into a new instance."""
        with open(filename, 'rb') as stream:
            data = stream.read()
        return cls(data)
    ## Read a .map file ##
    def read_all(self):
        """Parse every section, in the order they appear in the file."""
        self.read_header()
        self.read_map()
        self.read_clip_region()
        self.read_tileset_names()
        self.read_tiles()
    ## Helper functions for reading a .map file ##
    def next_k(self, k):
        """Return the next *k* bytes and advance the cursor past them."""
        # Check before advancing so a failed read leaves the cursor untouched;
        # a negative k would otherwise move the cursor backwards.
        assert 0 <= k and self.index + k <= self.N, 'unexpected end of map data'
        b = self.data[self.index : self.index + k]
        self.index += k
        return b
    def read_int4(self):
        """Read a little-endian signed 32-bit integer."""
        int4, = si4.unpack(self.next_k(4))
        return int4
    def read_int2(self):
        """Read a little-endian signed 16-bit integer."""
        int2, = si2.unpack(self.next_k(2))
        return int2
    def read_byte(self):
        """Read one unsigned byte."""
        return self.next_k(1)[0]
    def read_header(self):
        """Read the header: a tag, an ignored field, map size, tileset count."""
        # NOTE(review): presumably a format tag/version; only a lower bound
        # is checked here.
        assert self.read_int4() >= 0x1010
        self.read_int4()  # not used by this parser
        lg_width = self.read_int4()
        assert lg_width >= 0
        self.width = 2 ** lg_width  # width is stored as its base-2 logarithm
        self.height = self.read_int4()
        self.num_tilesets = self.read_int4()
    def read_map(self):
        """Read the ``width * height`` array of 32-bit cells into ``self.map``.

        Cells are stored in vertical strips 32 tiles wide, one strip after
        another; inside a strip they are stored row by row with x varying
        fastest.
        """
        start = self.index
        end = self.index + 4 * self.width * self.height
        assert end <= self.N
        raw_map = np.ndarray(shape = (self.width * self.height,),
                dtype = '<u4', # little endian
                buffer = self.data[start : end])
        self.map = np.zeros((self.width, self.height), dtype = raw_map.dtype)
        for x0 in range(0, self.width, 32):
            strip = raw_map[x0 * self.height : (x0 + 32) * self.height]
            # Fortran order sends consecutive cells to consecutive x values.
            self.map[x0 : x0 + 32, :] = strip.reshape((32, self.height), order = 'F')
        # Bits 5..15 of each cell index into self.tiles.
        self.map_tile_index = (self.map >> 5) & 0x7FF
        self.index = end
    def read_clip_region(self):
        """Skip the four 32-bit fields that follow the cell array."""
        # NOTE(review): presumably the clip rectangle this method is named
        # after; the values are not used here.
        for _ in range(4):
            self.read_int4()
    def read_tileset_names(self):
        """Read the tileset slot table.

        Each of ``num_tilesets`` slots starts with a 32-bit name length; a
        non-positive length marks an empty slot.  A used slot continues with
        the name (that many ASCII bytes) and a 32-bit tile count.  The used
        slots must form a contiguous prefix of the table; only that prefix
        is kept.
        """
        self.tileset_filenames = [None] * self.num_tilesets
        self.tileset_lengths = [0] * self.num_tilesets
        self.num_valid_tilesets = 0
        for i in range(self.num_tilesets):
            name_length = self.read_int4()
            if name_length > 0:
                # Read exactly the declared number of bytes: a fixed-size read
                # desynchronizes the parser for names of any other length.
                filename_base = self.next_k(name_length).decode('ascii')
                self.tileset_filenames[i] = filename_base + '.png'
                self.tileset_lengths[i] = self.read_int4()
                self.num_valid_tilesets += 1
        # A zero count inside the prefix means the used slots were not
        # contiguous (or a used slot declared no tiles).
        assert all(n > 0 for n in self.tileset_lengths[:self.num_valid_tilesets])
        self.tileset_filenames = self.tileset_filenames[:self.num_valid_tilesets]
        self.tileset_lengths = self.tileset_lengths[:self.num_valid_tilesets]
    def read_tiles(self):
        """Read the ``TILE SET`` section: a count followed by 8-byte records."""
        assert self.next_k(10) == b'TILE SET\x1a\x00'
        self.num_tiles = self.read_int4()
        self.tiles = [None] * self.num_tiles
        for i in range(self.num_tiles):
            tileset = self.read_int2()
            tileindex = self.read_int2()
            self.next_k(4) # animation data (skipped)
            self.tiles[i] = (tileset, tileindex)
    def create_image(self, z):
        """Render the map with *z* x *z* pixels per tile.

        Loads the tileset images through ``Tilesets`` and returns an array
        of shape ``(height * z, width * z, 4)``.  *z* must be a power of two
        from 1 to 32.
        """
        assert z in [1, 2, 4, 8, 16, 32]
        t = Tilesets(self)
        palette = t.prepare_palette(z)  # shape (z, num_tiles, z, 4)
        image = np.zeros((self.height * z, self.width * z, 4), dtype = t.dtype)
        for j in range(self.height):
            for a in range(z):
                # Pixel row a of every tile in map row j, laid side by side.
                image[j * z + a] = palette[a, self.map_tile_index[:, j], :, :].reshape((self.width * z, 4), order = 'C')
        return image
class Tilesets:
    """Tileset images referenced by a parsed ``Map``, loaded as 4-channel arrays."""
    def __init__(self, m):
        self.num_sets = m.num_valid_tilesets
        self.filenames = list(m.tileset_filenames)
        self.lengths = list(m.tileset_lengths)
        self.n = m.num_tiles
        self.tiles = list(m.tiles)
        self.read_tilesets()
    @staticmethod
    def _to_rgba(data):
        """Return *data* (gray, gray+alpha, RGB or RGBA) as an (H, W, 4) array.

        A missing alpha channel is filled with the dtype's opaque value (the
        integer maximum, or 1 for floating-point images).
        """
        if data.ndim == 2:
            data = data[:, :, None]
        channels = data.shape[2]
        if channels in (1, 2):
            color = np.repeat(data[:, :, :1], 3, axis = 2)
        elif channels in (3, 4):
            color = data[:, :, :3]
        else:
            raise ValueError('unsupported number of channels: {}'.format(channels))
        if channels in (2, 4):
            alpha = data[:, :, -1:]
        else:
            if np.issubdtype(data.dtype, np.integer):
                opaque = np.iinfo(data.dtype).max
            else:
                opaque = 1
            alpha = np.full(data.shape[:2] + (1,), opaque, dtype = data.dtype)
        return np.concatenate([color, alpha], axis = 2)
    def read_tilesets(self):
        """Load every tileset PNG from ``tileset_dir`` as an (H, W, 4) array.

        Tileset i must be a vertical strip of ``lengths[i]`` tiles, each
        32x32 pixels.
        """
        self.tilesets = [None] * self.num_sets
        self.dtype = np.dtype(np.uint8)  # fallback when there are no tilesets
        for i in range(self.num_sets):
            k = self.lengths[i]
            data = imageio.imread(os.path.join(tileset_dir, self.filenames[i]))
            assert data.shape[:2] == (32 * k, 32)
            # The palette is 4-channel; convert gray/RGB input and give it an
            # opaque alpha rather than reusing intensity as alpha.
            data = self._to_rgba(data)
            self.tilesets[i] = data
            self.dtype = data.dtype
    def prepare_palette(self, z = 8, *, debug_path = None):
        """Return every tile scaled to *z* x *z* pixels.

        The result has shape ``(z, n, z, 4)``: ``palette[a, i, b]`` is pixel
        (row a, column b) of tile ``self.tiles[i]``.  If *debug_path* is
        given, all tiles are also written side by side to that image file.
        """
        zoomed_tilesets = [zoom_image(tileset, z) for tileset in self.tilesets]
        palette = np.zeros((z, self.n, z, 4), dtype = self.dtype)
        for i, (w, idx) in enumerate(self.tiles):
            palette[:, i, :, :] = zoomed_tilesets[w][z * idx : z * (idx + 1), :, :]
        if debug_path is not None:
            # One z-pixel-high strip containing all n tiles in order.
            imageio.imwrite(debug_path, palette.reshape((z, self.n * z, 4)))
        return palette
def zoom_image(data, z):
    """Downscale a tile image so every 32x32 tile becomes z x z pixels.

    Each output pixel is the mean of an ``r x r`` block of input pixels,
    where ``r = 32 // z``.  Integer images are rounded to the nearest value
    and keep their dtype, so the result can be stored back into an array of
    the input dtype without overflow.  Trailing rows/columns that do not fill
    a whole block are dropped.  For ``z == 32`` *data* is returned unchanged
    (not copied).
    """
    if z == 32:
        return data
    assert z in [1, 2, 4, 8, 16, 32]
    r = 32 // z
    data = np.asarray(data)
    rows = data.shape[0] // r
    cols = data.shape[1] // r
    # Split into (rows, r, cols, r, ...) blocks and average over each block.
    blocks = data[:rows * r, :cols * r].reshape((rows, r, cols, r) + data.shape[2:])
    means = blocks.mean(axis = (1, 3))
    if np.issubdtype(data.dtype, np.integer):
        means = np.rint(means)
    return means.astype(data.dtype)
def run(pixels_per_tile, filename, output = 'map.png'):
    """Render the map file *filename* with *pixels_per_tile* pixels per tile.

    Prints the map size in tiles, then writes the rendered image to *output*
    (``map.png`` in the current directory by default).

    Raises:
        FileNotFoundError: if *filename* is not an existing regular file.
    """
    # An explicit check rather than assert, which is stripped under -O.
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)
    m = Map.read_file(filename)
    print(m.width, m.height)
    image = m.create_image(pixels_per_tile)
    imageio.imwrite(output, image)
if __name__ == "__main__":
    # Command line: PIXELS_PER_TILE MAPFILE (extra arguments are ignored).
    if len(sys.argv) < 3:
        sys.exit('usage: {} PIXELS_PER_TILE MAPFILE'.format(sys.argv[0]))
    pixels_per_tile = int(sys.argv[1])
    filename = sys.argv[2]
    run(pixels_per_tile, filename)