import numpy as np
import os

# Byte length of the leading header on BLOAD-style files (presumably the
# 7-byte QBasic BSAVE/BLOAD header -- confirm against the file format).
BLOAD_SIZE = 7
# Structured-dtype field that swallows that header as raw bytes.
BLOAD_HEADER = ('bload_header', (np.uint8, BLOAD_SIZE))

# Default scalar type for record fields: signed 16-bit integer.
H = np.int16


def fieldlist_to_dtype (fields):
    """Return a dtype list giving every whitespace-separated name in
    ``fields`` the type ``H``."""
    return [(fieldname, H) for fieldname in fields.split ()]

def mmap (fname, dtype, offset = 0, shape = None):
    """Map ``fname`` read-only as a numpy array of ``dtype``.

    ``offset`` bytes are skipped at the start of the file; ``shape`` limits
    the mapped region (``None`` maps the rest of the file).
    """
    options = dict (dtype = dtype, mode = 'r', offset = offset, shape = shape)
    return np.memmap (fname, **options)

def from_fieldnames (*allnames, **corrections):
    """Build a numpy ``{'names': ..., 'formats': ...}`` dtype spec.

    Each positional argument is a whitespace-separated string of field
    names; all fields default to type ``H``.  Each keyword argument
    overrides the format of the (first) field with that name.

    Raises ValueError if a keyword names a field that was not listed.
    """
    names = [field for group in allnames for field in group.split ()]
    formats = [H for _ in names]
    for field, fmt in corrections.items ():
        formats[names.index (field)] = fmt
    return {'names': names, 'formats': formats}

# Short alias used by the ``dtypes`` table.
make = from_fieldnames

def fvstr (maxlen):
    """Return the dtype of a length-prefixed string: an ``H`` length field
    followed by a ``maxlen``-wide character field named ``value``."""
    length_field = ('length', H)
    value_field = ('value', (np.character, maxlen))
    return [length_field, value_field]

# Planar layout: some formats store every x, then every y, and so on.
def planar_dtype (fieldnames, num, dtype, bload = False):
    """Return a dtype holding ``num`` values of ``dtype`` per field name.

    ``fieldnames`` is a whitespace-separated string; each name becomes one
    ``(dtype, num)`` sub-array, in order.  With ``bload`` the record is
    prefixed by ``BLOAD_HEADER``.
    """
    header = [BLOAD_HEADER] if bload else []
    planes = [(field, (dtype, num)) for field in fieldnames.split ()]
    return np.dtype (header + planes)

def fix_stringjunk (arr, fields = None, doubled = False):
    """Clear the bytes that follow each string's logical length, in place.

    ``arr`` is a structured array; each field named in ``fields`` (default:
    every field of ``arr``) must have ``length`` and ``value`` subfields, as
    produced by ``fvstr``.  Each ``value`` is cut to ``length`` bytes and
    written back; numpy null-pads the rest of the fixed-width slot.

    With ``doubled`` the string uses two bytes per character, so
    ``length * 2`` bytes are kept and every odd-offset byte is overwritten
    with ``'.'`` (the filler byte ``set_str16`` writes).
    """
    fields = fields or arr.dtype.names
    for name in fields:
        column = arr[name]
        lengths = column['length']
        values = column['value']
        # ndindex visits every element whatever the column's shape; writing
        # through ``values[idx]`` updates ``arr`` because it is a view.
        for idx in np.ndindex (values.shape):
            nchars = max (int (lengths[idx]), 0)
            # bytes() keeps the raw content on Python 3, where str() of a
            # numpy bytes scalar would give its "b'...'" repr instead.
            raw = bytes (values[idx])
            if doubled:
                buf = bytearray (raw[:nchars * 2])
                # Floor division: '/' would give a float under Python 3.
                buf[1::2] = b'.' * (len (buf) // 2)
                raw = bytes (buf)
            else:
                raw = raw[:nchars]
            values[idx] = raw

def pad (filename, granularity, groupsize = 1, headersize = 0):
    """Null-pad a file so its body is a whole number of record groups.

    The file is treated as ``headersize`` bytes of header followed by
    records of ``granularity`` bytes, grouped ``groupsize`` at a time.  If
    the body is not a whole multiple of ``granularity * groupsize`` bytes,
    null bytes are appended until it is.

    Returns the number of ``granularity``-sized records in the body after
    any padding.  (Previously the already-aligned path returned
    groups/groupsize instead; both agree when ``groupsize`` is 1.)

    Raises ValueError if the file is shorter than ``headersize``.
    """
    unit = granularity * groupsize
    bodysize = os.path.getsize (filename) - headersize
    if bodysize < 0:
        raise ValueError ('%s is shorter than its %d-byte header'
                          % (filename, headersize))
    # Bytes needed to round the body (not the whole file) up to the next
    # multiple of ``unit``; the header must be excluded here.
    shortfall = (-bodysize) % unit
    if shortfall:
        # Append mode always writes at end of file, so no seek is needed.
        with open (filename, 'ab') as f:
            f.write (b'\x00' * shortfall)
        bodysize += shortfall
    return bodysize // granularity

class fixBits (object):
    """Read/write view of a packed bit-flag file, one flag per bit.

    Field ``i`` of ``fields`` lives in byte ``i // 8`` at bit ``i % 8``
    (least significant bit first).  Flags are readable/writable as
    attributes (``fb.stuncancel``), by index (``fb[2]``) and by iteration;
    every access goes straight to the underlying file.
    """
    fields = ['attackitems', 'weappoints', 'stuncancel', 'defaultdissolve', 
              'defaultdissolveenemy', 'pushnpcbug_compat', 'default_maxitem',
              'blankdoorlinks', 
              'shopsounds', 'extended_npcs', 'heroportrait', 'textbox_portrait',
              'npclocation_format', ]

    def __init__ (self, filename, **kwargs):
        """Open ``filename`` for in-place update, then apply ``kwargs``."""
        # 'r+b' = read/write without truncation; the former 'rwb+' mode
        # string is rejected by Python 3's open().
        self.file = open (filename, 'r+b')
        for k, v in kwargs.items():
            setattr (self, k, v)

    def save (self, f):
        """Copy the whole flag file into the writable file object ``f``."""
        self.file.seek (0)
        f.write (self.file.read())

    def tostring (self):
        """Return the raw contents of the flag file."""
        self.file.seek (0)
        return self.file.read()

    def close (self):
        """Close the underlying file."""
        self.file.close ()

    # Backward-compatible alias: the original cleanup hook had this name,
    # but Python never calls ``__gc__`` automatically.
    __gc__ = close

    def __del__ (self):
        # Read __dict__ directly so a failed __init__ (no ``file`` yet)
        # does not trigger __getattr__ during finalisation.
        f = self.__dict__.get ('file')
        if f is not None:
            f.close ()

    def _bitpos (self, index):
        """Return (byte offset, bit mask) of the flag at ``index``."""
        return index // 8, 1 << (index % 8)

    def _readbyte (self, offset):
        """Return the byte at ``offset``; bytes past EOF read as 0."""
        self.file.seek (offset)
        data = self.file.read (1)
        return ord (data) if data else 0

    def __getitem__ (self, k):
        return self.__getattr__ (self.fields[k])

    def __setitem__ (self, k, v):
        self.__setattr__ (self.fields[k], v)

    def __getattr__ (self, k):
        # Only called when normal lookup fails: flag names map to bits,
        # anything else is a genuine missing attribute.
        try:
            index = object.__getattribute__ (self, 'fields').index (k)
        except ValueError:
            return object.__getattribute__ (self, k)
        offset, mask = self._bitpos (index)
        return 1 if self._readbyte (offset) & mask else 0

    def __setattr__ (self, k, v):
        try:
            index = object.__getattribute__ (self, 'fields').index (k)
        except ValueError:
            object.__setattr__ (self, k, v)
            return
        offset, mask = self._bitpos (index)
        value = self._readbyte (offset)
        value = (value | mask) if v else (value & ~mask)
        self.file.seek (offset)
        # bytearray is writable to binary files on both Python 2 and 3,
        # unlike chr() which is a text str on Python 3.
        self.file.write (bytearray ([value]))

    def __repr__ (self):
        kwargs = ", ".join (['%s = %d' % (name, v) for name, v in zip (self.fields, self)])
        return "%s (%r, %s)" % (self.__class__.__name__, self.file.name, kwargs) 

    def __iter__ (self):
        return iter ([getattr (self, k) for k in self.fields])

class archiNym (object):
    """Two-line CRLF-terminated file: prefix on line 0, version on line 1.

    Both lines are available by index (``a[0]``, ``a[1]``) and as the
    ``prefix`` / ``version`` properties; every access reads or rewrites the
    file directly.  Values are returned as bytes (``str`` on Python 2).
    """
    def __init__ (self, filename, **args):
        """Open ``filename`` for in-place update.

        Keyword arguments are applied by name, so ``prefix=`` and
        ``version=`` update the file; any other keyword becomes a plain
        attribute.
        """
        # 'r+b' = read/write without truncation ('rwb+' is rejected by
        # Python 3's open()).
        self.file = open (filename, 'r+b')
        # Apply values by name.  The old enumerate() over the kwargs dict
        # wrote the keyword *names* into lines 0/1.
        for k, v in args.items():
            setattr (self, k, v)

    @staticmethod
    def _checkindex (k):
        """Raise IndexError unless ``k`` addresses line 0 or 1."""
        if not -1 < k < 2:
            raise IndexError ('archiNym index out of range: %r' % (k,))

    def __getitem__ (self, k):
        self._checkindex (k)
        self.file.seek (0)
        for _ in range (k):
            self.file.readline()
        return self.file.readline().rstrip()

    def __setitem__ (self, k, v):
        self._checkindex (k)
        if not isinstance (v, bytes):
            # The file is binary; latin-1 maps code points below 256 to
            # single bytes one-to-one.
            v = v.encode ('latin-1')
        everything = [self[0], self[1]]
        everything[k] = v
        self.file.seek (0)
        for value in everything:
            self.file.write (value + b'\r\n')
        # Drop leftovers when the new contents are shorter than the old.
        self.file.truncate ()

    def close (self):
        """Close the underlying file."""
        self.file.close ()

    def __del__ (self):
        f = self.__dict__.get ('file')
        if f is not None:
            f.close ()

    def __repr__ (self):
        return '%s (%r, %r, %r)' % (self.__class__.__name__, self.file.name, 
                                    self[0], self[1])
    def _getprefix (self):
        return self[0]
    def _setprefix (self, v):
        self[0] = v
    def _getversion (self):
        return self[1]
    def _setversion (self, v):
        self[1] = v
    prefix = property (_getprefix, _setprefix)
    version = property (_getversion, _setversion)


def set_str16 (dest, src): # 16bit/8bit len, 16bit chars
    """Store ``src`` in record ``dest`` using two bytes per character.

    ``dest`` is a structured record with ``length`` and ``value`` fields.
    Each character is followed by a ``'.'`` filler byte; ``length`` is set
    to the character count.

    Raises ValueError if ``src`` does not fit in the ``value`` field.
    """
    # Capacity comes from the field's declared width.  len(dest['value'])
    # measured the current, null-stripped content, so a zeroed record
    # rejected every non-empty string.
    capacity = dest.dtype['value'].itemsize // 2
    if len (src) > capacity:
        raise ValueError ('string of %d chars exceeds capacity %d'
                          % (len (src), capacity))
    dest['length'] = len(src)
    # No explicit terminator: numpy null-pads a shorter value on assignment.
    dest['value'] = "".join([char + '.' for char in src])

def set_str8 (dest, src): # 16 or 8bit header, 8bit chars
    """Store ``src`` in record ``dest`` using one byte per character.

    ``dest`` is a structured record with ``length`` and ``value`` fields;
    ``length`` is set to the character count.

    Raises ValueError if ``src`` does not fit in the ``value`` field.
    """
    # Capacity comes from the field's declared width.  len(dest['value'])
    # measured the current, null-stripped content, so a zeroed record
    # rejected every non-empty string.
    capacity = dest.dtype['value'].itemsize
    if len (src) > capacity:
        raise ValueError ('string of %d chars exceeds capacity %d'
                          % (len (src), capacity))
    dest['length'] = len(src)
    # No explicit terminator: numpy null-pads a shorter value on assignment.
    dest['value'] = src

def get_str16 (src):
    """Return the even-offset bytes of ``src['value']``, i.e. the
    characters of a string stored with one filler byte after each."""
    raw = src['value']
    return raw[0::2]

def get_str8 (src):
    """Return ``src['value']`` unchanged (one byte per character)."""
    value = src['value']
    return value

# set_str8/set_str16 also clear any 'junk' bytes after the string: numpy
# null-pads a shorter value when it is assigned to a fixed-width field.

def adjust_for_binsize (dtype, binsize):
    """Trim trailing fields from a names/formats dtype dict to fit ``binsize``.

    The last field is dropped from ``dtype['names']`` and
    ``dtype['formats']`` (in place) while the record is larger than
    ``binsize`` bytes.  Returns None.

    Raises ValueError if the resulting size is not exactly ``binsize``,
    including when the dtype was already smaller than ``binsize``.
    """
    while True:
        size = np.dtype (dtype).itemsize
        if size <= binsize:
            break
        for key in ('names', 'formats'):
            dtype[key].pop ()
    if size != binsize:
        raise ValueError ('dtype is misaligned with binsize!')

def vstr (len):
    """Return the dtype of an OHRRPGCE string occupying ``len`` bytes.

    One unsigned byte holds the string length and the remaining
    ``len - 1`` bytes hold one byte per character.
    """
    # The parameter keeps its historical name (it shadows the builtin)
    # so keyword callers keep working.
    nchars = len - 1
    return [('length', np.uint8), ('value', (np.character, nchars))]

def vstr2 (len):
    """Return the dtype of an OHRRPGCE string occupying ``len`` 16-bit words.

    A uint16 length is followed by ``len - 1`` 16-bit characters stored as
    ``(len - 1) * 2`` raw bytes, so the record totals ``2 * len`` bytes
    (not ``len`` bytes).
    """
    # NOTE(review): the character field is named 'data' here, while
    # set_str16/get_str16 and fix_stringjunk access 'value' -- confirm
    # which name callers expect before combining them.
    nbytes = (len - 1) * 2
    return [('length', np.uint16), ('data', (np.character, nbytes))]

# Stat field names, in on-disk order.
_statlist = ['hp', 'mp', 'str', 'acc', 'def', 'dog', 'mag', 'wil',
             'spd', 'ctr', 'foc', 'xhits']
# One H per stat.
STATS_DTYPE = list (zip (_statlist, [H] * len (_statlist)))
# Two H per stat (presumably the level-0 and level-99 values -- confirm).
STATS0_99_DTYPE = list (zip (_statlist, [(H, 2)] * len (_statlist)))
xycoord_dtype = [('x', H), ('y', H)]
# H length + 38-wide character value; shared by both fields of 'browse.txt'.
_browse_base_dtype = fvstr (38)

# (width, height, frames) of each picture type, indexed by type number;
# used to build the 'pt<N>' entries of ``dtypes``.
ptshapes = (
    (32, 40, 8),
    (34, 34, 1),
    (50, 50, 1),
    (80, 80, 1),
    (20, 20, 8),
    (24, 24, 2),
    (50, 50, 3),
    (16, 16, 16),
    (50, 50, 1),
)

# Record layouts, keyed by the lump/file name or extension they describe
# (inferred from keys such as 'attack.bin' and 'browse.txt').
# Bit-set widths are written as ``bits // 8`` so they stay integers under
# Python 3, where ``64 / 8`` is the float 8.0 -- not a valid dtype shape.
dtypes = {
    '_attack' : make ('picture palette animpattern targetclass targetsetting',
                         'damage_eq aim_math baseatk_stat cost xdamage chainto chain_percent',
                         'attacker_anim attack_anim attack_delay nhits target_stat',
                         'preftarget bitsets1 name captiontime caption basedef_stat',
                         'settag tagcond tagcheck settag2 tagcond2 tagcheck2 bitsets2',
                         'description consumeitem nitems_consumed soundeffect',
                         'stat_preftarget',
                         bitsets1 = ('B', 64 // 8), bitsets2 = ('B', 128 // 8),
                         cost = [('hp', H), ('mp', H), ('money', H)], 
                         name = [('length', H), ('unused', H), ('value', (H, 10))],
                         caption = fvstr (38),
                         description = fvstr (38)),
    'attack.bin' : make ('captionpt2 basedef_stat',
                         'settag tagcond tagcheck settag2 tagcond2 tagcheck2 bitsets2',
                         'description consumeitem nitems_consumed soundeffect',
                         'stat_preftarget',
                         bitsets2 = ('B', 128 // 8), 
                         captionpt2 = 'S36',
                         description = fvstr (38),
                         consumeitem = (H, 3),
                         nitems_consumed = (H, 3)),
    'binsize.bin' : make ('attack.bin stf songdata.bin sfxdata.bin map',
                          'menus.bin menuitem.bin uicolors.bin say'),
    'browse.txt' : [('longname', _browse_base_dtype), ('about', _browse_base_dtype)],
    'defpass.bin' : [('passability', (H, 160)), ('magic', H)],
    'd' : planar_dtype ('srcdoor destdoor destmap condtag1 condtag2', 100, H),
    'dt1' : make ('name thievability stealable_item stealchance',
                  'raresteal_item raresteal_chance dissolve dissolvespeed',
                  'deathsound unused picture palette picsize rewards stats',
                  'bitsets spawning attacks unused2',
                  name = vstr2 (17), 
                  rewards = make ('gold exp item itemchance rareitem rareitemchance'),
                  bitsets = ('B', 10),
                  spawning = make ('death non_e_death alone non_e_hit',
                                   'elemhit n_tospawn',
                                   elemhit = (H, 8)),
                  attacks = [('regular', (H, 5)), ('desperation', (H, 5)),
                             ('alone', (H, 5)), ('counter', (H, 8))],
                  stats = STATS_DTYPE,
                  unused = (H, 28), 
                  unused2 = (H, 45)),
    'dt6' : make ('picture palette animpattern targetclass targetsetting',
                  'damage_eq aim_math baseatk_stat cost xdamage chainto chain_percent',
                  'attacker_anim attack_anim attack_delay nhits target_stat',
                  'preftarget bitsets1 name captiontime captionpt1',
                  bitsets1 = ('B', 64 // 8), 
                  cost = [('hp', H), ('mp', H), ('money', H)],
                  name = [('length', H), ('unused', H), ('value', 'S20')],
                  captionpt1 = fvstr (4)),
    'efs' : [('frequency', H),('formations',(H, 20)), ('wasted', (H, 4))],
    'for' : make ('enemies background music backgroundframes backgroundspeed unused',
                  enemies = (make ('type x y unused'), 8), unused = (H, 4)),
    'map' : make ('tileset music minimap_available save_anywhere display_name_time',
                  'edge_mode edge_tile autorun_trigger autorun_arg harmtile_damage',
                  'harmtile_flash foot_offset afterbattle_trigger',
                  'insteadofbattle_trigger each_step_trigger keypress_trigger draw_herosfirst',
                  'npcanddoor_loading tileandwall_loading bitsets savoffset layer_tilesets',
                  'n_npc_instances', savoffset = xycoord_dtype, bitsets = ('B', 2),
                  layer_tilesets = (H, 3)),
    'mxs' : [('planes', (np.uint8, (4, 16000)))],
    'menuitem.bin' : make ('membership caption sort_order type subtype',
                           'tagcond1 tagcond2 settag toggletag bitsets extra',
                           extra = (H,3), bitsets = ('B',2), caption = fvstr (38)),
    'veh' : make ('name speed bitsets randbattles usebutton menubutton ridingtag onmount',
                  'ondismount overridewalls blockedby mountfrom dismount_to elevation reserved',
                  name = vstr (16), bitsets = (np.uint8, 4), reserved = (H, 18)),
    'palettes.bin' : [('color', ([('r', np.uint8,), ('g', np.uint8), ('b', np.uint8)], 256))],
    'sfxdata.bin' : [('name', fvstr (30)), ('streaming', H)],
    'songdata.bin' : [('name', fvstr (30))],
    'til' : [('planes', (np.uint8, (4, 16000)))],
    'tmn' : vstr2 (21),
    }

# Only needed while building the table above.
del _browse_base_dtype

# Add one 'pt<i>' entry per ptshapes row: a flat block of
# (w // 2) * h * frames bytes (presumably two 4-bit pixels per byte --
# confirm).  Floor division keeps the size an int under Python 3, and the
# generator expression leaks no loop names into the module, so no `del`
# is needed afterwards.
dtypes.update (
    ('pt%d' % i, [('pixels', (np.uint8, (w // 2) * h * frames))])
    for i, (w, h, frames) in enumerate (ptshapes))

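# Apparently left over from building the 'veh' dtype incrementally with
# fieldlist_to_dtype; these fields now appear in dtypes['veh'] above.
#>>> fields += fieldlist_to_dtype ('randbattles usebutton menubutton ridingtag')
#>>> fields += fieldlist_to_dtype ('onmount ondismount overridewalls blockedby')
#>>> fields += fieldlist_to_dtype ('mountfrom dismount_to elevation')
#>>> fields += [('reserved', (H, 18))]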


