import math
import json 
from pathlib import Path 
from PIL import Image

# Input files -----------------------------------------------------------------

# Machine-specific absolute path; adjust to wherever the graphics live.
mapFolder = Path("/Volumes/Dean/My Code/Spectrum/Z80/Graphics/")
mapFilename, tileSheetFilename = "map.json", "tiles.png"	# both relative to mapFolder

# Conversion limits and geometry ----------------------------------------------

# Most distinct non-empty tiles a single map row may use; also the number of
# tile slots emitted for every tileset.
maxTilesPerSet = 5

tileWidthPX, tileHeightPX = 16, 16			# size of one tile in pixels (width, height)
tileSheetWidth, tileSheetHeight = 16, 3		# tile grid of the tilesheet (columns, rows)

# Helper functions ----------------------------------------------------------

def dumpHex(p_rowArray, p_sep = ""):
	"""Return the values of p_rowArray as two-digit uppercase hex joined by p_sep.

	Each value is formatted with "{:02X}" (e.g. 10 -> "0A", 255 -> "FF").
	p_sep goes only *between* values, so no trailing separator is produced
	whatever p_sep is. An empty p_rowArray gives "".
	"""
	return p_sep.join("{:02X}".format(value) for value in p_rowArray)

def findTileset(p_tileSet, p_Row):
	"""Return the "number" of the first tileset whose "rows" list contains p_Row.

	p_tileSet is a dict whose values are dicts holding at least "rows" and
	"number". Tilesets are searched in the dict's iteration order; None is
	returned when no tileset lists p_Row.
	"""
	matches = (entry["number"] for entry in p_tileSet.values() if p_Row in entry["rows"])
	return next(matches, None)

# The TileSheet class -------------------------------------------------------

class TileSheet:
	"""One-bit tile graphics sliced from a tilesheet image, plus a DG assembly emitter.

	Tiles are read left-to-right, top-to-bottom, so tileData[0] is the top-left
	tile. Each tile is a list of tileHeightPX integers, one per pixel row, in
	which bit (tileWidthPX - 1 - x) is set when the pixel in column x is not
	pure black.
	"""

	tileData = None			# list of tiles; each instance gets its own list in __init__
	tileWidthPX = None		# width of each tile in pixels
	tileHeightPX = None		# height of each tile in pixels
	tileSheetWidth = None	# number of tile columns in the sheet
	tileSheetHeight = None	# number of tile rows in the sheet

	def __getTile(self, px, p_X, p_Y, p_W, p_H):
		"""Return the p_H row bitmasks of the p_W x p_H tile at grid cell (p_X, p_Y).

		px is a pixel accessor indexed as px[x, y] that yields (r, g, b) tuples.
		A pixel becomes a 1 bit when any component is non-zero; the leftmost
		pixel maps to the most significant of the p_W bits.
		"""
		ox = p_X * p_W
		oy = p_Y * p_H
		dataArray = []
		for y in range(p_H):
			data = 0
			for x in range(p_W):
				r, g, b = px[ox + x, oy + y]
				if (r | g | b) != 0:
					data |= 1 << (p_W - x - 1)
			dataArray.append(data)
		return dataArray

	def __init__(self, p_Filename, p_TileSheetWidth, p_TileSheetHeight, p_TileWidthPX, p_TileHeightPX):
		"""Load p_Filename and slice it into p_TileSheetWidth x p_TileSheetHeight tiles.

		Raises ValueError if the image is smaller than the tile grid requires.
		"""
		self.tileSheetWidth = p_TileSheetWidth
		self.tileSheetHeight = p_TileSheetHeight
		self.tileWidthPX = p_TileWidthPX
		self.tileHeightPX = p_TileHeightPX
		# Must be per-instance: a list defined on the class is shared by every
		# TileSheet, so a second sheet would append after the first one's tiles.
		self.tileData = []

		# The with-block closes the file handle. convert("RGB") returns a new
		# converted copy (usable after the file is closed) whose pixels always
		# unpack as (r, g, b); palette, greyscale and RGBA PNGs would otherwise
		# break the unpacking in __getTile.
		with Image.open(p_Filename) as img:
			rgb = img.convert("RGB")

		width, height = rgb.size
		neededW = self.tileSheetWidth * self.tileWidthPX
		neededH = self.tileSheetHeight * self.tileHeightPX
		if width < neededW or height < neededH:
			raise ValueError("{}: image is {}x{} px, but {}x{} tiles of {}x{} px need {}x{} px".format(
				p_Filename, width, height,
				self.tileSheetWidth, self.tileSheetHeight, self.tileWidthPX, self.tileHeightPX,
				neededW, neededH))

		px = rgb.load()
		for y in range(self.tileSheetHeight):
			for x in range(self.tileSheetWidth):
				self.tileData.append(self.__getTile(px, x, y, self.tileWidthPX, self.tileHeightPX))

	def getAsm(self, p_Label, p_TileArray, p_maxTilesPerSet):
		"""Print one tileset as DG graphics lines, one line per pixel row.

		The first line starts with p_Label followed by a colon. Each line holds
		p_maxTilesPerSet tile rows, drawn as '1' for set bits and '-' for clear
		ones. p_TileArray holds 1-based tile numbers (number n is tileData[n - 1]).
		Only its first p_maxTilesPerSet entries are emitted; unused slots and 0
		entries are emitted as blank tiles.
		"""
		formatString = "{{:0{:d}b}}".format(self.tileWidthPX)
		tileCount = len(p_TileArray)
		for y in range(self.tileHeightPX):
			asm = (p_Label + ":\t\t") if y == 0 else "\t\t\t"
			sep = "DG "
			for i in range(p_maxTilesPerSet):
				t = p_TileArray[i] if i < tileCount else 0
				# 0 marks padding; tileData[0 - 1] would silently wrap to the
				# sheet's last tile instead of producing a blank one.
				bitData = self.tileData[t - 1][y] if t > 0 else 0
				asm += sep + formatString.format(bitData).replace("0", "-")
				sep = " "
			print(asm)

# The TileSets class --------------------------------------------------------

class TileSets:
	"""Holds tileset entries in a per-instance ``values`` dict.

	NOTE(review): main() in this file builds its own tileset dict and does not
	reference this class.
	"""

	values = None	# class-level placeholder; __init__ gives every instance its own dict

	def __init__(self):
		self.values = dict()

# The Map class -------------------------------------------------------------

class MapRow:
	"""One row of the tile map plus helpers to re-index it against its own tileset.

	tiles holds the row's raw tile numbers; 0 is treated as an empty cell and
	never becomes part of the row's tileset.
	"""

	rowNumber = None	# zero-based row index within the map
	tiles = None		# raw tile numbers for this row

	def __init__(self, p_rowNumber, p_Tiles):
		self.tiles = p_Tiles
		self.rowNumber = p_rowNumber

	def getHash(self):
		"""Return the sorted, de-duplicated non-zero tile numbers used by this row.

		Rows with identical lists can share one tileset. An empty or all-zero row
		yields [] (indexing the sorted list's first element would fail on an empty
		row).
		"""
		return sorted(set(self.tiles) - {0})

	def getConvertedRow(self):
		"""Return the row with each tile replaced by its 1-based position in getHash().

		Empty cells (0) stay 0, so every value indexes into the row's tileset.
		"""
		# Dict lookup instead of list.index(): one pass, O(1) per tile.
		slotOf = {tile: slot for slot, tile in enumerate(self.getHash(), start=1)}
		return [0 if tile == 0 else slotOf[tile] for tile in self.tiles]

	def dump(self):
		"""Return a debug line: row number, converted row (hex) and tileset hash (hex)."""
		return "{:03d}: {:s} .. # {:s}".format(self.rowNumber, dumpHex(self.getConvertedRow(), " "), dumpHex(self.getHash()))

# Main ----------------------------------------------------------------------

def main():
	"""Convert the map JSON and tilesheet into Z80 assembly printed on stdout.

	Output is one DG block per distinct row tileset (labelled Tileset_NN, numbered
	in order of first use), followed by one MAP_ROW line per map row. Each MAP_ROW
	line names the row's tileset and lists its tiles re-indexed into it.

	Raises SystemExit, before printing anything, if a row uses more than
	maxTilesPerSet distinct tiles.
	"""
	tileSheet = TileSheet(mapFolder / tileSheetFilename, tileSheetWidth, tileSheetHeight, tileWidthPX, tileHeightPX)

	# JSON text is UTF-8; don't depend on the platform's default encoding.
	with open(mapFolder / mapFilename, 'r', encoding='utf-8') as f:
		mapJson = json.load(f)

	mapHeight = mapJson["height"]
	mapWidth = mapJson["width"]
	# Only the first layer is converted. This assumes it is a tile layer whose
	# "data" is a flat list of tile numbers, mapWidth per row.
	# TODO confirm for multi-layer maps or encoded/compressed layer data.
	mapData = mapJson["layers"][0]["data"]

	# First pass: split the flat tile list into rows and check each fits one tileset.
	mapRows = []
	for y in range(mapHeight):
		start = y * mapWidth
		mapRow = MapRow(y, mapData[start : start + mapWidth])
		tileCount = len(mapRow.getHash())
		if tileCount > maxTilesPerSet:
			# Stop before printing anything rather than emit a map that is missing
			# rows and whose tileset has been truncated to maxTilesPerSet tiles.
			raise SystemExit("Error - row {:d} uses {:d} distinct tiles, maxTilesPerSet is {:d}".format(y, tileCount, maxTilesPerSet))
		mapRows.append(mapRow)

	# Group rows that use the same set of tiles. The tuple key cannot collide,
	# unlike concatenated hex, where tile numbers above 0xFF are ambiguous.
	tileSets = {}
	for mapRow in mapRows:
		hashArray = mapRow.getHash()
		key = tuple(hashArray)
		if key not in tileSets:
			tileSets[key] = {"number": len(tileSets), "hash": hashArray, "rows": []}
		tileSets[key]["rows"].append(mapRow.rowNumber)

	# Dump out the tilesets, in order of first use.
	for tileSet in tileSets.values():
		tileSheet.getAsm("Tileset_{:02d}".format(tileSet["number"]), tileSet["hash"], maxTilesPerSet)
		print("")

	# Dump out the map rows, each referencing its tileset.
	for mapRow in mapRows:
		rowStr = ",".join(str(x) for x in mapRow.getConvertedRow())
		tileSetNumber = findTileset(tileSets, mapRow.rowNumber)
		print("\t\t\tMAP_ROW Tileset_{:02d}, {:s}".format(tileSetNumber, rowStr))

# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
	main()