#!/usr/bin/env python3

import os
import io
import sys
import re
import argparse
import json
import xml.etree.ElementTree as ET

# On msys, write the generated files with CRLF line endings; elsewhere
# newline=None keeps Python's default newline translation.
nl = None
if sys.platform == 'msys':
    nl = "\r\n"

# Get the file, relative to this script's location (same directory)
# that way we're not sensitive to CWD
pathname = os.path.abspath(os.path.dirname(sys.argv[0])) + os.path.sep

# JSON is UTF-8 by specification, so decode it explicitly instead of relying on
# the locale's preferred encoding (e.g. cp1252 on Windows), which can mangle or
# reject non-ASCII text that later gets copied into the generated files.
with open(pathname + 'spirv.core.grammar.json', mode='r', encoding='utf-8') as f:
    spirv = json.load(f)

with open(pathname + 'extinst.glsl.std.450.grammar.json', mode='r', encoding='utf-8') as f:
    glsl450 = json.load(f)

# open XML registry
registry = ET.parse(pathname + 'spir-v.xml').getroot()

# Open the generated outputs for write. They are written as UTF-8 as well so
# that text taken from the inputs round-trips independently of the locale.
# NOTE(review): these handles are not closed in this part of the script;
# presumably they are closed once generation finishes - confirm.
header = open(pathname + 'spirv_gen.h', mode='w', newline = nl, encoding='utf-8')
ops_header = open(pathname + 'spirv_op_helpers.h', mode='w', newline = nl, encoding='utf-8')
cpp = open(pathname + 'spirv_gen.cpp', mode='w', newline = nl, encoding='utf-8')

###############################################################################
##
## Headers
##
###############################################################################

def prefix_star(line):
    """Format *line* as one body line of a C block comment.

    Non-empty text gets a ' * ' prefix; an empty string becomes a bare ' *'
    so the generated comment carries no trailing whitespace.
    """
    return ' *' if line == '' else ' * ' + line

def operand_name(name, lowercase_first = True):
    """Turn a SPIR-V grammar operand/enumerant name into a C++ identifier.

    Args:
        name: the name as it appears in the grammar (it may contain quotes,
            spaces, newlines and AsciiDoc ``<<target,text>>`` cross-references).
        lowercase_first: when True (default) the first character of the result
            is lower-cased, giving a camelCase member name.

    Returns:
        The identifier. A few awkward variadic names map to fixed plurals
        ('members', 'parameters', 'arguments', 'parents'), and names that would
        clash with C++ keywords are renamed ('interface' -> 'iface',
        'default' -> 'def').
    """
    name = name.replace('\n', ' ')
    # special case a few very awkward (variadic, "...") names
    if re.search(r'member [0-9].*\.\.\.', name, re.I):
        return 'members'
    if re.search(r'parameter [0-9].*\.\.\.', name, re.I):
        return 'parameters'
    if re.search(r'argument [0-9].*\.\.\.', name, re.I):
        return 'arguments'
    if re.search(r'variable, parent.*\.\.\.', name, re.I):
        return 'parents'

    # Replace each AsciiDoc cross-reference <<target,text>> with its text (the
    # part after the last comma inside the reference). Matching cannot cross a
    # '>', so a name with several references keeps each reference's text
    # instead of one greedy match swallowing everything between them.
    name = re.sub(r'<<[^>]*,([^,>]*)>>', r'\1', name)
    name = re.sub(r'[ \'~<>./-]', '', name)

    # avoid C++ keywords
    if name.lower() == 'interface':
        return 'iface'

    if name.lower() == 'default':
        return 'def'

    if lowercase_first:
        return name[0].lower() + name[1:]
    else:
        return name

# Licence banner shared by every generated file. It holds the MIT licence of
# the generated code, followed by the Khronos grammar's own copyright text
# (each entry of spirv['copyright'] stripped and turned into a ' * ' comment
# line). The surrounding whitespace of the whole banner is stripped. Note that
# this rebinds the `copyright` builtin that site.py provides for interactive use.
copyright = '''
/******************************************************************************
 * The MIT License (MIT)
 *
 * Copyright (c) 2020 Baldur Karlsson
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 ******************************************************************************/

/******************************************************************************
 * Generated from Khronos SPIR-V machine-readable JSON grammar.
 *
{}
 ******************************************************************************/
'''.format("\n".join([prefix_star(line.strip()) for line in spirv['copyright']])).strip()

# Write the prologue of spirv_gen.h. It contains the licence banner, the
# includes, version constants taken from the grammar (magic number,
# major/minor/revision), the strongly-typed rdcspv::Id wrapper, and the
# opening of `enum class Generator`. The registry loop that follows fills in
# and closes that enum. Braces meant for C++ are doubled because the template
# goes through str.format.
header.write('''{copyright}

#pragma once

// This file is autogenerated with gen_spirv_code.py - any changes will be overwritten next time
// that script is run.
// $ ./gen_spirv_code.py

// We need to disable clang-format since this file is programmatically generated
// clang-format off

#include <stdint.h>
#include "api/replay/apidefs.h"
#include "api/replay/stringise.h"

#undef None
#undef CopyMemory
#undef MemoryBarrier

namespace rdcspv
{{
static const uint32_t MagicNumber = {magic};
static const uint32_t VersionMajor = {major};
static const uint32_t VersionMinor = {minor};
static const uint32_t VersionRevision = {revision};
static const uint32_t VersionPacked = ({major} << 16) | ({minor} << 8);
static const uint32_t OpCodeMask = 0xffff;
static const uint32_t WordCountShift = 16;
static const uint32_t FirstRealWord = 5;

struct Id
{{
  constexpr inline Id() : id(0) {{}}
  // only allow explicit functions to cast to/from uint32_t
  constexpr static inline Id fromWord(uint32_t i) {{ return Id(i); }}
  inline uint32_t value() const {{ return id; }}
  constexpr inline explicit operator bool() const {{ return id != 0; }}
  constexpr inline bool operator==(const Id o) const {{ return id == o.id; }}
  constexpr inline bool operator!=(const Id o) const {{ return id != o.id; }}
  constexpr inline bool operator<(const Id o) const {{ return id < o.id; }}
  constexpr inline bool operator==(const uint32_t o) const {{ return id == o; }}
  constexpr inline bool operator!=(const uint32_t o) const {{ return id != o; }}
  constexpr inline bool operator<(const uint32_t o) const {{ return id < o; }}
private:
  constexpr inline Id(uint32_t i) : id(i) {{}}
  uint32_t id;
}};

enum class Generator : uint32_t
{{'''.format(copyright = copyright, magic = spirv['magic_number'], major = spirv['major_version'], minor = spirv['minor_version'], revision = spirv['revision']))

# Fill the `enum class Generator` opened by the header prologue with one
# enumerator per registered generator tool in the XML registry, then close it.
# The matching STRINGISE_ENUM_CLASS_NAMED lines are accumulated in
# generator_tostr.
generator_tostr = ''

for gen in registry.findall('ids[@type=\'vendor\']/id[@tool]'):
    name = operand_name(gen.attrib['tool'], lowercase_first=False)

    tostr = '{} from {} - {}'.format(gen.attrib['tool'], gen.attrib['vendor'], gen.attrib['comment'])

    # Escape for a C string literal. Backslashes must be doubled BEFORE quotes
    # are escaped; otherwise the backslash inserted for each quote is doubled
    # too, leaving the quote unescaped and terminating the literal.
    generator_tostr += '    STRINGISE_ENUM_CLASS_NAMED({}, "{}");\n'.format(name, tostr.replace('\\', '\\\\').replace('"', '\\"'))
    header.write('\n  {} = {},'.format(name, gen.attrib['value']))

header.write('\n};\n\n')

# Write the prologue of spirv_op_helpers.h: the licence banner, includes,
# generic DecodeParam specialisations (uint32_t, Id, rdcstr), MultiParam (which
# decodes values until the instruction's words run out) and EncodeParam for
# strings. EncodeParam packs the string plus its NUL terminator four bytes per
# word, little-end first. Each char is converted through uint8_t before
# widening: plain `char` may be signed, and sign-extending a non-ASCII byte
# would OR 0xFF into the higher bytes of the word.
ops_header.write('''{copyright}

#pragma once

// This file is autogenerated with gen_spirv_code.py - any changes will be overwritten next time
// that script is run.
// $ ./gen_spirv_code.py

// We need to disable clang-format since this file is programmatically generated
// clang-format off

#include <functional>
#include <set>
#include <stdint.h>
#include "api/replay/apidefs.h"
#include "api/replay/rdcstr.h"
#include "api/replay/rdcarray.h"
#include "api/replay/stringise.h"

#undef None
#undef CopyMemory
#undef MemoryBarrier

#include "spirv_common.h"
#include "spirv_gen.h"

namespace rdcspv
{{

template<typename Type>
Type DecodeParam(const ConstIter &it, uint32_t &word);

template<>
inline uint32_t DecodeParam(const ConstIter &it, uint32_t &word)
{{
  if(word >= it.size()) return 0;
  
  uint32_t ret = it.word(word);
  word += 1;
  return ret;
}}

template<>
inline Id DecodeParam<Id>(const ConstIter &it, uint32_t &word)
{{
  if(word >= it.size()) return Id();
  
  Id ret = Id::fromWord(it.word(word));
  word += 1;
  return ret;
}}

template<>
inline rdcstr DecodeParam<rdcstr>(const ConstIter &it, uint32_t &word)
{{
  if(word >= it.size()) return "";
  
  rdcstr ret = (const char *)&it.word(word);
  word += uint32_t(ret.size() / 4) + 1;
  return ret;
}}

template<typename Type>
rdcarray<Type> MultiParam(const ConstIter &it, uint32_t &word)
{{
  rdcarray<Type> ret;
  while(word < it.size())
  {{
    Type t = DecodeParam<Type>(it, word);
    ret.push_back(t);
  }}
  return ret;
}}

inline void EncodeParam(rdcarray<uint32_t> &words, const rdcstr &str)
{{
  size_t i=0, remainingChars = str.size() + 1;
  while(remainingChars > 0)
  {{
    uint32_t word = 0;
    for(size_t w=0; w < remainingChars && w < 4; w++)
      word |= uint32_t(uint8_t(str[i+w])) << (w*8);
    words.push_back(word);
    
    i += 4;
    if(remainingChars < 4)
      remainingChars = 0;
    else
      remainingChars -= 4;
  }}
}}

'''.format(copyright = copyright))

# Write the prologue of spirv_gen.cpp: the licence banner and the includes
# needed by the DoStringise/ParamToStr definitions appended later.
cpp.write('''{copyright}

// This file is autogenerated with gen_spirv_code.py - any changes will be overwritten next time
// that script is run.
// $ ./gen_spirv_code.py

// We need to disable clang-format since this file is programmatically generated
// clang-format off

#include "spirv_gen.h"
#include "os/os_specific.h"
#include "common/formatting.h"
#include "spirv_op_helpers.h"

'''.format(copyright = copyright))

###############################################################################
##
## Operands (declare enums, stringise, preprocess)
##
###############################################################################

# Fallback member names for multi-parameter enumerant structs, indexed by
# parameter position, used when the grammar gives a parameter no name.
positional_names = [ 'first', 'second', 'third' ]
# Operand kind name -> its grammar entry. The entries are mutated in place
# below to carry code-generation metadata used by later passes.
kinds = {}

for operand_kind in spirv['operand_kinds']:
    name = operand_kind['kind']

    # A kind "has params" when any of its enumerants takes extra operand words.
    if 'enumerants' in operand_kind:
        operand_kind['has_params'] = any('parameters' in value for value in operand_kind['enumerants'])
    else:
        operand_kind['has_params'] = False

    kinds[name] = operand_kind

    # Code-generation hooks, overridden per kind below:
    #   push_words(expr): C++ statement(s) appending expr to the `words` array
    #   from_words(expr): C++ expression converting the word `expr` to the type
    operand_kind['push_words'] = lambda name: 'words.push_back((uint32_t){});'.format(name)
    operand_kind['from_words'] = None
    operand_kind['is_id'] = False

    if operand_kind['category'] == 'ValueEnum':
        operand_kind['size'] = 1
        operand_kind['def_name'] = name[0].lower() + name[1:]
        operand_kind['def_value'] = name + '::Invalid'
        operand_kind['type'] = name

        decl = ''
        stringise = ''

        used = []

        for value in operand_kind['enumerants']:
            value_name = value['enumerant']
            # C++ identifiers cannot start with a digit
            if value_name[0].isdigit():
                value_name = '_' + value_name
            decl += '  {} = {},\n'.format(value_name, value['value'])

            # Only stringise the first enumerant seen for each value
            if value['value'] in used:
                continue

            used.append(value['value'])

            if value_name != value['enumerant']:
                stringise += '    STRINGISE_ENUM_CLASS_NAMED({}, "{}");\n'.format(value_name, value['enumerant'])
            else:
                stringise += '    STRINGISE_ENUM_CLASS({});\n'.format(value_name)

        header.write('''enum class {name} : uint32_t
{{
{values}
  Max,
  Invalid = ~0U,
}};

'''.format(name = name, values = decl.rstrip()))

        cpp.write('''template <>
rdcstr DoStringise(const rdcspv::{name} &el)
{{
  BEGIN_ENUM_STRINGISE(rdcspv::{name});
  {{
{values}
  }}
  END_ENUM_STRINGISE();
}}

'''.format(name = name, values = stringise.rstrip()))
    elif operand_kind['category'] == 'BitEnum':
        operand_kind['size'] = 1
        operand_kind['def_name'] = name[0].lower() + name[1:]
        operand_kind['def_value'] = name + '::None'
        operand_kind['type'] = name

        used = []

        decl = ''
        stringise = ''
        for value in operand_kind['enumerants']:
            decl += '  {} = {},\n'.format(value['enumerant'], value['value'])

            # Only stringise the first enumerant seen for each value
            if value['value'] in used:
                continue

            used.append(value['value'])

            if value['enumerant'] == 'None':
                stringise += '    STRINGISE_BITFIELD_CLASS_VALUE(None);\n\n'
            else:
                stringise += '    STRINGISE_BITFIELD_CLASS_BIT({});\n'.format(value['enumerant'])

        header.write('''enum class {name} : uint32_t
{{
{values}
  Max,
  Invalid = ~0U,
}};

BITMASK_OPERATORS({name});

'''.format(name = name, values = decl.rstrip()))

        cpp.write('''template <>
rdcstr DoStringise(const rdcspv::{name} &el)
{{
  BEGIN_BITFIELD_STRINGISE(rdcspv::{name});
  {{
{values}
  }}
  END_BITFIELD_STRINGISE();
}}

'''.format(name = name, values = stringise.rstrip()))
    # Hardcoded special types that we hardcode behaviour for
    elif operand_kind['kind'] == 'IdRef':
        operand_kind['size'] = 1
        operand_kind['def_name'] = 'id'
        operand_kind['def_value'] = 'Id()'
        operand_kind['type'] = 'Id'
        operand_kind['is_id'] = True
        operand_kind['push_words'] = lambda name: 'words.push_back({}.value());'.format(name)
        operand_kind['from_words'] = lambda name: 'Id::fromWord({})'.format(name)
    elif operand_kind['kind'] in ('IdResultType', 'IdResult', 'IdMemorySemantics', 'IdScope'):
        operand_kind['size'] = 1
        operand_kind['type'] = name
        operand_kind['is_id'] = True
        # drop the 'Id' prefix: IdResultType -> resultType
        operand_kind['def_name'] = name[2].lower() + name[3:]
        operand_kind['def_value'] = name + '()'
        operand_kind['push_words'] = lambda name: 'words.push_back({}.value());'.format(name)
        operand_kind['from_words'] = lambda name: 'Id::fromWord({})'.format(name)
        header.write('using {} = Id;\n\n'.format(name))
    # For simplicity, assume literal integers are 32-bit in size
    elif operand_kind['kind'] == 'LiteralInteger':
        operand_kind['size'] = 1
        operand_kind['def_name'] = 'num'
        operand_kind['def_value'] = '0'
        operand_kind['type'] = 'uint32_t'
    elif operand_kind['kind'] == 'LiteralString':
        # Variable length. NOTE(review): the huge negative size presumably makes
        # any word count summed from it obviously invalid - confirm intent.
        operand_kind['size'] = -1000000
        operand_kind['type'] = 'rdcstr'
        operand_kind['def_name'] = 'str'
        operand_kind['def_value'] = '""'
        operand_kind['push_words'] = lambda name: 'EncodeParam(words, {});'.format(name)
        operand_kind['from_words'] = lambda name: 'DecodeParam({})'.format(name)
    elif operand_kind['kind'] in ('LiteralContextDependentNumber', 'LiteralExtInstInteger', 'LiteralSpecConstantOpInteger'):
        # size depends on context; no type/def_name is set for these kinds
        operand_kind['size'] = None
    elif operand_kind['kind'] == 'PairLiteralIntegerIdRef':
        operand_kind['size'] = 2
        operand_kind['def_name'] = name[0].lower() + name[1:]
        operand_kind['def_value'] = '{0, Id()}'
        operand_kind['type'] = name
        operand_kind['push_words'] = lambda name: 'words.push_back((uint32_t){0}.first); words.push_back({0}.second.value());'.format(name)
        ops_header.write('struct {} {{ uint32_t first; Id second; }};\n\n'.format(name))
    elif operand_kind['kind'] == 'PairIdRefLiteralInteger':
        operand_kind['size'] = 2
        operand_kind['def_name'] = name[0].lower() + name[1:]
        operand_kind['def_value'] = '{Id(), 0}'
        operand_kind['type'] = name
        operand_kind['push_words'] = lambda name: 'words.push_back({0}.first.value()); words.push_back((uint32_t){0}.second);'.format(name)
        ops_header.write('struct {} {{ Id first; uint32_t second; }};\n\n'.format(name))
    elif operand_kind['kind'] == 'PairIdRefIdRef':
        operand_kind['size'] = 2
        operand_kind['def_name'] = name[0].lower() + name[1:]
        operand_kind['def_value'] = '{Id(), Id()}'
        operand_kind['type'] = name
        operand_kind['push_words'] = lambda name: 'words.push_back({0}.first.value()); words.push_back({0}.second.value());'.format(name)
        ops_header.write('struct {} {{ Id first, second; }};\n\n'.format(name))
        # NOTE(review): skips the default from_words assignment below, so
        # from_words stays None for this kind (unlike the other pair kinds).
        continue
    else:
        raise TypeError("Unexpected operand {} of type {}".format(operand_kind['kind'], operand_kind['category']))

    # Default conversion is a C-style cast to the kind's type. `kind` is bound
    # as a default argument to avoid late binding of the loop variable.
    if operand_kind['from_words'] is None:
        operand_kind['from_words'] = lambda name,kind=operand_kind: '({}){}'.format(kind['type'], name)

# DecodeParam specialisations for the two-word pair operand types. The
# `word >= it.size()` check only covers the first word, so the second read is
# guarded separately: a truncated trailing pair yields a default second member
# instead of reading index word+1 past it.size(). `word` still advances by 2 in
# every non-early-return case, as before, so MultiParam's
# `while(word < it.size())` loop still terminates.
ops_header.write('''
template<>
inline PairIdRefIdRef DecodeParam(const ConstIter &it, uint32_t &word)
{
  if(word >= it.size()) return {};
  
  PairIdRefIdRef ret = { Id::fromWord(it.word(word)), word + 1 < it.size() ? Id::fromWord(it.word(word+1)) : Id() };
  word += 2;
  return ret;
}

template<>
inline PairLiteralIntegerIdRef DecodeParam(const ConstIter &it, uint32_t &word)
{
  if(word >= it.size()) return {};
  
  PairLiteralIntegerIdRef ret = { it.word(word), word + 1 < it.size() ? Id::fromWord(it.word(word+1)) : Id() };
  word += 2;
  return ret;
}

template<>
inline PairIdRefLiteralInteger DecodeParam(const ConstIter &it, uint32_t &word)
{
  if(word >= it.size()) return {};
  
  PairIdRefLiteralInteger ret = { Id::fromWord(it.word(word)), word + 1 < it.size() ? uint32_t(it.word(word+1)) : 0U };
  word += 2;
  return ret;
}
''')

# Accumulated C++ text for the ParamToStr specialisations (definitions and
# declarations) of operand kinds whose enumerants carry parameters.
tostrs = ''
tostr_decls = ''

# Second pass to declare operand parameter structs in ops helper header
for operand_kind in spirv['operand_kinds']:
    name = operand_kind['kind']

    # Kinds without parameterised enumerants need no param structs; plain
    # ValueEnums only get an OptionalWordCount helper (Invalid means absent).
    if not operand_kind['has_params']:
        if operand_kind['category'] == 'ValueEnum':
            ops_header.write('inline uint16_t OptionalWordCount(const {0} val) {{ return val != {0}::Invalid ? 1 : 0; }}\n\n'.format(name))
        continue

    # C++ fragments accumulated over this kind's enumerants:
    #   values           - members of the parameter union/struct
    #   set_unset        - setX()/unsetX() helpers (only emitted for BitEnums)
    #   word_count_cases - switch cases for ExtraWordCount
    #   decode_cases     - body of the DecodeParam specialisation
    #   encode_cases     - body of EncodeParam
    #   constructors     - <Kind>Param<value> helper specialisations
    #   tostr_cases      - body of ParamToStr
    values = ''
    set_unset = ''
    word_count_cases = ''
    decode_cases = ''
    encode_cases = ''
    constructors = ''
    tostr_cases = ''

    value_enum = operand_kind['category'] == 'ValueEnum'
    bit_enum = operand_kind['category'] == 'BitEnum'

    used = []

    for value in operand_kind['enumerants']:
        params = ''
        assign = ''
        ret_assign = ''

        # Only the first enumerant seen for each value gets decode/encode/
        # stringise/word-count code.
        new_value = value['value'] not in used
        used.append(value['value'])

        # BitEnum stringise: open the statement `if(flag) ret += "Name"`. It is
        # completed below with an optional "(params)" part and a trailing ", ".
        if new_value and bit_enum:
            tostr_cases  += '  if(el.flags & {0}::{1})\n    ret += "{1}"'.format(name, value['enumerant'])

        if 'parameters' in value:
            # We want plain unions, so don't include strings
            if any(param['kind'] == 'LiteralString' for param in value['parameters']):
                # The BitEnum stringise statement opened above must still be
                # terminated, otherwise it runs into the next generated line.
                if new_value and bit_enum:
                    tostr_cases  += ' ", ";\n'
                continue

            if new_value and value_enum:
                tostr_cases  += '    case {0}::{1}:\n      ret += '.format(name, value['enumerant'])

            member = ""
            param_name = operand_name(value['enumerant'])
            size = 0

            if new_value:
                if value_enum:
                    decode_cases += '    case {0}::{1}:\n'.format(name, value['enumerant'])
                    encode_cases += '    case {0}::{1}:\n'.format(name, value['enumerant'])
                else:
                    decode_cases += '  if(ret.flags & {0}::{1})\n  {{\n'.format(name, value['enumerant'])
                    encode_cases += '  if(param.flags & {0}::{1})\n  {{\n'.format(name, value['enumerant'])

            # if we only have one parameter, add its type to the set
            if len(value['parameters']) == 1:
                param = value['parameters'][0]
                size += kinds[param['kind']]['size']
                param_type = kinds[param['kind']]['type']
                member = "{} {};\n".format(param_type, param_name)

                # ValueEnum members live inside a union and switch cases, so
                # they get one extra indentation level
                if value_enum:
                    values += '  '
                    if new_value:
                        decode_cases += '  '
                        encode_cases += '  '

                values += "  " + member
                params += "{} {}Param".format(param_type, param_name)
                assign += " {0} = {0}Param;".format(param_name)
                ret_assign += "    ret.{0} = {0};\n".format(param_name)

                if new_value:
                    decode_cases += '    ret.{} = {};\n'.format(param_name, kinds[param['kind']]['from_words']('it.word(word)'))
                    encode_cases += '    {}\n'.format(kinds[param['kind']]['push_words']('param.{}'.format(param_name)))
                    if kinds[param['kind']]['is_id']:
                        tostr_cases += ' "(" + idName(el.{}) + ")"'.format(param_name)
                    else:
                        tostr_cases += ' "(" + ToStr(el.{}) + ")"'.format(param_name)

            # if we have multiple we need a separate struct for this thing
            else:
                struct_name = param_name[0].upper() + param_name[1:] + 'Params'
                member = "{} {};\n".format(struct_name, param_name)
                if value_enum:
                    values += '  '
                values += "  " + member

                struct_values = ''

                if new_value:
                    tostr_cases += ' "("'

                for i,param in enumerate(value['parameters']):
                    subparam_name = positional_names[i]
                    kind = kinds[param['kind']]
                    size += kind['size']
                    if 'name' in param:
                        subparam_name = operand_name(param['name'])
                    struct_values += "  {} {};\n".format(kind['type'], subparam_name)

                    if new_value:
                        if value_enum:
                            decode_cases += '  '
                            encode_cases += '  '
                        decode_cases += '    ret.{}.{} = {};\n'.format(param_name, subparam_name, kinds[param['kind']]['from_words']('it.word(word+{})'.format(i)))
                        encode_cases += '    {}\n'.format(kinds[param['kind']]['push_words']('param.{}.{}'.format(param_name, subparam_name)))
                        if kinds[param['kind']]['is_id']:
                            tostr_cases  += ' + idName(el.{}.{}) + '.format(param_name, subparam_name)
                        else:
                            tostr_cases  += ' + ToStr(el.{}.{}) + '.format(param_name, subparam_name)

                    assign += " {0}.{1} = {1};".format(param_name, subparam_name)
                    ret_assign += "    ret.{0}.{1} = {0}.{1};\n".format(param_name, subparam_name)
                    params += "{} {}".format(kind['type'], subparam_name)
                    if i != len(value['parameters'])-1:
                        params += ", "
                        tostr_cases += '", " '

                if new_value:
                    tostr_cases += '")"'

                header.write('''struct {struct_name}
{{
{struct_values}
}};

'''.format(struct_name = struct_name, struct_values = struct_values.rstrip()))

            # close the per-value decode/encode/stringise code, advance the
            # decode cursor past the parameter words, and record the count
            if new_value:
                if value_enum:
                    decode_cases += '      word += {};\n'.format(size)
                    decode_cases += '      break;\n'
                    encode_cases += '      break;\n'
                    tostr_cases  += '; break;\n'
                else:
                    decode_cases += '    word += {};\n'.format(size)
                    decode_cases += '  }\n'
                    encode_cases += '  }\n'
                word_count_cases += '    case {}::{}: return {};\n'.format(name, value['enumerant'], size)

                constructors += '''template<>\nstruct {name}Param<{name}::{value}>
{{
  {member}
  {name}Param({params}) {{ {assign} }}
  operator {name}AndParamData()
  {{
    {name}AndParamData ret({name}::{value});
{ret_assign}
    return ret;
  }}
}};

'''.format(value=value['enumerant'], member=member.rstrip(), name=name, params=params, assign=assign, ret_assign=ret_assign.rstrip())

        # terminate the BitEnum stringise statement opened above
        if new_value and bit_enum:
            tostr_cases  += ' ", ";\n'


        set_unset += '''  void set{flag}({params}) {{ flags |= {name}::{flag};{assign} }}
  void unset{flag}() {{ flags &= ~{name}::{flag}; }}
'''.format(flag=value['enumerant'], name=name, params=params, assign=assign)

    if constructors != '':
        constructors = 'template<{name} val> struct {name}Param;\n\n'.format(name=name) + constructors

    # ValueEnums are set up as one or many pairs of enum/params, enum/params, etc. So we declare a struct for the pair
    # and declare an array if the op wants many
    if value_enum:
        tostrs += '''template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcspv::{name}AndParamData &el)
{{
  rdcstr ret = ToStr(el.value);

  switch(el.value)
  {{
{tostr_cases}
    default:
      break;
  }}

  return ret;
}}

'''.format(name=name, tostr_cases=tostr_cases.rstrip())

        tostr_decls += '''template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcspv::{name}AndParamData &el);'''.format(name=name)

        header.write('''struct {name}AndParamData
{{
  {name}AndParamData({name} v = {name}::Invalid) : value(v) {{}}
  {name} value;
  union
  {{
{values}
  }};
  
  operator {name}() const {{ return value; }}
  bool operator ==(const {name} v) const {{ return value == v; }}
}};

'''.format(name=name, values=values.rstrip()))

        ops_header.write('''{constructors}

template<>
inline {name}AndParamData DecodeParam(const ConstIter &it, uint32_t &word)
{{
  if(word >= it.size()) return {name}AndParamData();

  {name}AndParamData ret(({name})it.word(word));
  word++;
  switch(ret.value)
  {{
{decode_cases}
    default: break;
  }}
  return ret;
}}

inline void EncodeParam(rdcarray<uint32_t> &words, const {name}AndParamData &param)
{{
  words.push_back((uint32_t)param.value);
  switch(param.value)
  {{
{encode_cases}
    default: break;
  }}
}}

'''.format(name=name, value_name=operand_name(name), decode_cases=decode_cases.rstrip(),
           constructors=constructors, encode_cases=encode_cases.rstrip()))
        operand_kind['type'] = '{}AndParamData'.format(name)
    # BitEnums are set up with one bitmask, and then a series of parameters, so we declare a struct with an array
    elif bit_enum:
        tostrs += '''template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcspv::{name}AndParamDatas &el)
{{
  rdcstr ret;
  
{tostr_cases}

  // remove trailing ", "
  if(ret.size() > 2)
    ret.erase(ret.size()-2, 2);

  return ret;
}}

'''.format(name=name, tostr_cases=tostr_cases.rstrip())

        tostr_decls += '''template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcspv::{name}AndParamDatas &el);'''.format(name=name)

        header.write('''struct {name}AndParamDatas
{{
  {name}AndParamDatas({name} f = {name}::None) : flags(f) {{}}
  {name} flags;
{values}
  
  operator {name}() const {{ return flags; }}
  bool operator &(const {name} v) const {{ return bool(flags & v); }}
{set_unset}
}};

'''.format(name=name, values=values.rstrip(), set_unset=set_unset.rstrip()))

        ops_header.write('''template<>
inline {name}AndParamDatas DecodeParam(const ConstIter &it, uint32_t &word)
{{
  if(word >= it.size()) return {name}AndParamDatas();

  {name}AndParamDatas ret(({name})it.word(word));
  word++;
{decode_cases}
  return ret;
}}

inline void EncodeParam(rdcarray<uint32_t> &words, const {name}AndParamDatas &param)
{{
  words.push_back((uint32_t)param.flags);
{encode_cases}
}}

'''.format(name=name,  decode_cases=decode_cases.rstrip(), encode_cases=encode_cases.rstrip()))
        operand_kind['type'] = '{}AndParamDatas'.format(name)
    else:
        raise TypeError("unexpected operand kind {} with parameters".format(operand_kind['category']))

    # number of parameter words following the enum word for each value
    ops_header.write('''inline uint16_t ExtraWordCount(const {name} {value_name})
{{
  switch({value_name})
  {{
{word_count_cases}
    default: break;
  }}
  return 0;
}}

'''.format(name = name, value_name = operand_name(name), word_count_cases = word_count_cases.rstrip()))

# Emit word-count helpers for the non-enum operand types:
#  - ExtraWordCount(rdcstr): size/4, i.e. the words a NUL-terminated string
#    needs beyond the first one.
#  - OptionalWordCount(...): 0 when an optional operand is "absent" (empty
#    string, null Id, or a pair whose Id member is null), else its word count.
#  - MultiWordCount: sums sizeof(element)/sizeof(uint32_t) over an array.
#    NOTE(review): this only equals the encoded size for types whose in-memory
#    size matches their word count (presumably uint32_t, Id and the Pair*
#    structs); it would be wrong for rdcstr or the *AndParamData(s) types.
#    Confirm it is never instantiated with those.
ops_header.write('''
inline uint16_t ExtraWordCount(const rdcstr &val)
{
  return uint16_t(val.size() / 4);
}

inline uint16_t OptionalWordCount(const rdcstr &val)
{
  if(val.empty()) return 0;
  return uint16_t(val.size() / 4) + 1;
}

inline uint16_t OptionalWordCount(const Id &val)
{
  return val != Id() ? 1 : 0;
}

inline uint16_t OptionalWordCount(const PairIdRefLiteralInteger &val)
{
  return val.first != Id() ? 2 : 0;
}

inline uint16_t OptionalWordCount(const PairLiteralIntegerIdRef &val)
{
  return val.second != Id() ? 2 : 0;
}

inline uint16_t OptionalWordCount(const PairIdRefIdRef &val)
{
  return val.first != Id() ? 2 : 0;
}

template<typename Type>
uint16_t MultiWordCount(const rdcarray<Type> &multiParams)
{
  uint16_t ret = 0;
  for(size_t i=0; i < multiParams.size(); i++)
    ret += sizeof(multiParams[i])/sizeof(uint32_t);
  return ret;
}
''')

###############################################################################
##
## Opcodes (declare enum / stringise)
##
###############################################################################

# Quickly preprocess, find parameters with duplicated names and disambiguate.
# Several operands of one instruction can map to the same generated member name. Every
# clashing operand gets its 'name' rewritten to "<generated name><n>", with n counting up
# across all clashing operands of the instruction in the order the pairwise scan meets them.
for inst in spirv['instructions']:
    if 'operands' not in inst:
        continue

    operands = inst['operands']

    # With fewer than two operands nothing can clash.
    if len(operands) < 2:
        continue

    # Generated (sanitised) member name per operand, computed once up front.
    gen_names = [operand_name(op['name'] if 'name' in op else kinds[op['kind']]['def_name'])
                 for op in operands]

    # Indices of clashing operands, in first-encounter order of the (first, second) scan;
    # this order determines the numeric suffixes.
    clashing = []
    for first, first_name in enumerate(gen_names):
        for second in range(first + 1, len(gen_names)):
            if gen_names[second] != first_name:
                continue
            for k in (first, second):
                if k not in clashing:
                    clashing.append(k)

    for suffix, k in enumerate(clashing):
        operands[k]['name'] = gen_names[k] + str(suffix)

# Opcode values already processed. The grammar may list the same opcode value under more
# than one opname: every opname gets an `Op` enum entry, but only the first one per value
# gets a struct / stringise / decoder case (repeats would emit duplicate `case` values).
used = set()
decl = ''         # `Op` enum entries
stringise = ''    # DoStringise(Op) cases
op_structs = ''   # per-opcode C++ structs for spirv_op_helpers.h
op_decoder = ''   # OpDecoder constructor cases (result / result type extraction)
used_ids = ''     # OpDecoder::AddUsedIDs cases
disassemble = ''  # OpDecoder::Disassemble cases


def _declare_word(init, declared, start_word):
    """Ensure the generated decoding constructor declares its ``word`` cursor.

    Returns ``(init, True)``. *init* is returned unchanged when *declared* is already true;
    otherwise a ``uint32_t word = <start_word>;`` line is appended. Routing every
    declaration through here guarantees it is emitted at most once per generated struct and
    before any ``DecodeParam``/``MultiParam`` line that reads it.
    """
    if declared:
        return init, True
    return init + '    uint32_t word = {};\n'.format(start_word), True


for inst in spirv['instructions']:
    decl += '  {} = {},\n'.format(inst['opname'][2:], inst['opcode'])

    if inst['opcode'] in used:
        continue

    stringise += '    STRINGISE_ENUM_CLASS({});\n'.format(inst['opname'][2:])

    # Operand position + 1 of the result id / result type id (used below as the word index,
    # word 0 being the packed opcode/wordcount), or -1 when the instruction has none.
    result = -1
    resultType = -1

    used_ids += '    case rdcspv::Op::{}:\n'.format(inst['opname'][2:])

    operands = []

    if 'operands' in inst:
        operands = inst['operands']

        for i,operand in enumerate(operands):
            if operand['kind'] == 'IdResult':
                result = i+1
            if operand['kind'] == 'IdResultType':
                resultType = i+1

    disassemble += '    case rdcspv::Op::{}:\n'.format(inst['opname'][2:])
    disassemble += '    {\n'

    if any(kinds[operand['kind']]['size'] is None for operand in operands):
        # Some operand has no size the generator can lay out: only forward-declare the struct
        # and disassemble through the generic OpDecoder (result / result type only).
        op_struct = 'struct {}; // has operands with variable sizes\n\n'.format(inst['opname'])

        disassemble += '      OpDecoder decoded(it);\n'

        if resultType > 0 and result > 0:
            disassemble += '      ret += declName(decoded.resultType, decoded.result) + " = ";\n'
        elif resultType > 0 and result == -1:
            raise ValueError("Unexpected result type without result")
        elif resultType == -1 and result > 0:
            disassemble += '      ret += idName(decoded.result) + " = ";\n'

        disassemble += '      ret += "{}(...)";\n'.format(inst['opname'][2:])
        disassemble += '      break;\n'
        disassemble += '    }\n'
    else:
        params = ''       # C++ constructor parameter list
        assign = ''       # constructor body assigning members from parameters
        member_decl = ''  # struct member declarations
        size_name = 'FixedWordSize'
        construct_size = 'FixedWordSize'  # C++ expression for wordCount in the parameter constructor
        size = 1 # opcode / wordcount packed
        all_size = 1 # size, but with all optionals included
        iter_init = '    memcpy(this, &(*it), sizeof(*this));'
        complex_type = False   # decode word-by-word (manual_init) instead of memcpy
        word_declared = False  # manual_init already declares the C++ `word` cursor
        manual_init = '    this->op = OpCode;\n'
        manual_init += '    this->wordCount = (uint16_t)it.size();\n'
        oper_cast = '  operator Operation() const\n  {\n    rdcarray<uint32_t> words;\n'
        has_funcs = ''

        disassemble += '      Op{} decoded(it);\n'.format(inst['opname'][2:])

        if resultType > 0 and result > 0:
            disassemble += '      ret += declName(decoded.resultType, decoded.result) + " = ";\n'
        elif resultType > 0 and result == -1:
            raise ValueError("Unexpected result type without result")
        elif resultType == -1 and result > 0:
            disassemble += '      ret += idName(decoded.result) + " = ";\n'

        disassemble += '      ret += "{}("'.format(inst['opname'][2:])

        disassemble_params = False

        if 'operands' in inst:
            for i,operand in enumerate(operands):
                kind = kinds[operand['kind']]

                if kind['has_params'] and not complex_type:
                    size_name = 'MinWordSize'
                    construct_size = 'MinWordSize'
                    complex_type = True
                    manual_init, word_declared = _declare_word(manual_init, word_declared, all_size)

                if kind['is_id']:
                    used_ids += '      usedids.insert(Id::fromWord(it.word({})));\n'.format(all_size)

                quantifier = ''

                if 'quantifier' in operand:
                    quantifier = operand['quantifier']
                    if not complex_type:
                        size_name = 'MinWordSize'
                        construct_size = 'MinWordSize'
                        complex_type = True
                        if quantifier == '*':
                            manual_init, word_declared = _declare_word(manual_init, word_declared, all_size)

                if kind['size'] < 0:
                    # Only switch to the variable-size form if no earlier operand did. Resetting
                    # construct_size unconditionally discarded the word-count terms already
                    # appended for earlier operands (such as an optional id followed by an
                    # optional string), producing a wrong wordCount in the constructor.
                    if not complex_type:
                        size_name = 'MinWordSize'
                        construct_size = 'MinWordSize'
                        complex_type = True
                    manual_init, word_declared = _declare_word(manual_init, word_declared, all_size)

                opType,opName = (kind['type'], operand_name(operand['name'] if 'name' in operand else kind['def_name']))

                # Result / result-type ids are printed before the opcode name, not as parameters.
                if i+1 != resultType and i+1 != result:
                    if quantifier == '*':
                        disassemble += ' + ParamsToStr(idName, decoded.{})'.format(opName)
                    else:
                        if opType == 'IdScope':
                            disassemble += ' + ToStr(Scope(constIntVal(decoded.{})))'.format(opName)
                        elif opType == 'IdMemorySemantics':
                            disassemble += ' + ToStr(MemorySemantics(constIntVal(decoded.{})))'.format(opName)
                        else:
                            disassemble += ' + ParamToStr(idName, decoded.{})'.format(opName)

                    if i+1 < len(operands):
                        disassemble += ' + ", "'

                    disassemble_params = True

                if quantifier == '?':
                    params += '{} {} = {}, '.format(opType, opName, kind['def_value'])
                elif quantifier == '*':
                    params += 'const rdcarray<{}> &{} = {{}}, '.format(opType, opName)
                else:
                    params += '{} {}, '.format(opType, opName)

                if quantifier == '*':
                    member_decl += '  rdcarray<{}> {};\n'.format(opType, opName)
                else:
                    member_decl += '  {} {};\n'.format(opType, opName)
                assign += '    this->{} = {};\n'.format(opName, opName)

                if operand['kind'] == 'LiteralString':
                    if quantifier == '*':
                        raise ValueError('operand {} in {} is string but has * quantifier'.format(opName, inst['opname']))
                    manual_init, word_declared = _declare_word(manual_init, word_declared, all_size)
                    manual_init += '    this->{name} = DecodeParam<{type}>(it, word);\n'.format(name = opName, type = opType)
                    oper_cast += '    EncodeParam(words, {name});\n'.format(name = opName)
                    if quantifier == '?':
                        construct_size += ' + OptionalWordCount({})'.format(opName)
                        has_funcs += '  bool Has{name}() const {{ return wordCount > {idx}; }}\n'.format(idx = all_size, name = opName[0].upper() + opName[1:])
                    else:
                        construct_size += ' + ExtraWordCount({})'.format(opName)

                elif kind['has_params']:
                    if quantifier == '*':
                        raise ValueError('operand {} in {} has * quantifier and params'.format(opName, inst['opname']))
                    manual_init, word_declared = _declare_word(manual_init, word_declared, all_size)
                    manual_init += '    this->{name} = DecodeParam<{type}>(it, word);\n'.format(name = opName, type = opType)
                    oper_cast += '    EncodeParam(words, {name});\n'.format(name = opName)
                    construct_size += ' + ExtraWordCount({})'.format(opName)
                elif quantifier == '*':
                    manual_init, word_declared = _declare_word(manual_init, word_declared, all_size)
                    manual_init += '    this->{name} = MultiParam<{type}>(it, word);\n'.format(name = opName, type = opType)
                    construct_size += ' + MultiWordCount({})'.format(opName)
                    oper_cast += '    for(size_t i=0; i < {name}.size(); i++)\n'.format(name = opName)
                    oper_cast += '    {\n'
                    oper_cast += '      {push_words}\n'.format(push_words = kind['push_words']('{}[i]'.format(opName)))
                    oper_cast += '    }\n'
                elif quantifier == '?':
                    manual_init += '    this->{name} = (it.size() > {idx}) ? {value} : {def_value};\n'.format(name = opName, idx = all_size, value = kind['from_words']('it.word({})'.format(all_size)), def_value=kind['def_value'])
                    construct_size += ' + OptionalWordCount({})'.format(opName)
                    oper_cast += '    if({name} != {def_value}) {push_words}\n'.format(name = opName, def_value=kind['def_value'], push_words = kind['push_words'](opName))
                    has_funcs += '  bool Has{name}() const {{ return wordCount > {idx}; }}\n'.format(idx = all_size, name = opName[0].upper() + opName[1:])
                else:
                    manual_init += '    this->{name} = {value};\n'.format(name = opName, value = kind['from_words']('it.word({})'.format(all_size)))
                    oper_cast += '    {push_words}\n'.format(push_words = kind['push_words'](opName))

                if kind['size'] >= 0:
                    all_size += kind['size']
                else:
                    all_size += 1

                # Unquantified operands are always present, so they extend the minimum size.
                if quantifier == '':
                    size = all_size
        else:
            assign = '    // no operands'
            member_decl = '  // no operands'

        if complex_type:
            iter_init = manual_init.rstrip()
            oper_cast += '    return Operation(OpCode, words);\n  }\n'
        else:
            oper_cast = ''

        # drop the trailing ", " separator (no-op on an empty list)
        params = params[:-2]

        if disassemble_params:
            disassemble += ' + ")";\n'
        else:
            disassemble += ' ")";\n'
        disassemble += '      break;\n'
        disassemble += '    }\n'

        if has_funcs != '':
            has_funcs = '\n\n' + has_funcs

        op_struct = '''struct {name}
{{
  {name}(const ConstIter &it)
  {{
{iter_init}
  }}
  {name}({params})
      : op(Op::{opname})
      , wordCount({construct_size})
  {{
{assign}
  }}
{oper_cast}
  static constexpr Op OpCode = Op::{opname};
  static constexpr uint16_t {size_name} = {size}U;
  Op op;
  uint16_t wordCount;
{member_decl}{has_funcs}
}};

'''.format(opname=inst['opname'][2:], name=inst['opname'], params=params, iter_init=iter_init, assign=assign.rstrip(),
           member_decl=member_decl.rstrip(), size_name=size_name, construct_size=construct_size,
           oper_cast=oper_cast, size=size, has_funcs=has_funcs.rstrip())


    op_structs += op_struct

    # Sanity check that quantifiers only happen on final operands. Also if there are multiple they are all ?, not *
    # (checked by position; unnamed operands are reported by their kind).
    if inst.get('operands'):
        operands = inst['operands']
        last_operand = operands[-1]

        for operand in operands[:-1]:
            if 'quantifier' in operand and (last_operand.get('quantifier') != operand['quantifier'] or operand['quantifier'] != '?'):
                raise ValueError('quantifier on operand {} in {} but not on last operand'.format(operand.get('name', operand['kind']), inst['opname']))

    used_ids += '      break;\n'

    if result < 0:
        result = ' result = Id();'
    else:
        result = ' result = Id::fromWord(it.word({}));'.format(result)

    if resultType < 0:
        resultType = ' resultType = Id();'
    else:
        resultType = ' resultType = Id::fromWord(it.word({}));'.format(resultType)

    op_decoder += '    case rdcspv::Op::{}:{}{} break;\n'.format(inst['opname'][2:], result, resultType)

    used.add(inst['opcode'])

# Emit the `Op` enum: one "  Name = value," line per opname (repeated opcode values
# included), followed by a trailing Max sentinel.
op_enum_text = 'enum class Op : uint16_t\n{\n' + decl + '\n  Max,\n};\n\n'
header.write(op_enum_text)

# Template for the tail of spirv_op_helpers.h: the generated op structs, the ParamToStr /
# ParamsToStr disassembly helpers (specialisations defined in spirv_gen.cpp) and the
# OpDecoder declaration. Doubled braces are literal braces under str.format.
_OPS_HEADER_TAIL_TEMPLATE = '''
{op_structs}

template<typename T>
inline rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const T &el)
{{
  return ToStr(el);
}}

template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const Id &el);
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcstr &el);
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairLiteralIntegerIdRef &el);
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdRefLiteralInteger &el);
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdRefIdRef &el);

{tostr_decls}

template<typename U>
inline rdcstr ParamsToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcarray<U> &ids)
{{
  rdcstr ret = "{{";
  for(size_t i=0; i < ids.size(); i++)
  {{
    ret += ParamToStr(idName, ids[i]);
    if(i + 1 < ids.size())
      ret += ", ";
  }}
  ret += "}}";
  return ret;
}}

struct OpDecoder
{{
  OpDecoder(const ConstIter &it);

  static void AddUsedIDs(std::set<Id> &usedids, const ConstIter &it);
  static rdcstr Disassemble(const ConstIter &it, const std::function<rdcstr(Id,Id)> &declName, const std::function<rdcstr(rdcspv::Id)> &idName, const std::function<uint32_t(Id)> &constIntVal);
  
  Op op;
  uint16_t wordCount;
  Id result;
  Id resultType;
}};
'''

ops_header.write(_OPS_HEADER_TAIL_TEMPLATE.format(tostr_decls = tostr_decls,
                                                  op_structs = op_structs.rstrip()))

# Template for the Op section of spirv_gen.cpp: DoStringise(Op), the ParamToStr
# specialisations declared in the ops header, and the OpDecoder member definitions.
# Doubled braces are literal braces under str.format.
_OP_CPP_TEMPLATE = '''template <>
rdcstr DoStringise(const rdcspv::Op &el)
{{
  BEGIN_ENUM_STRINGISE(rdcspv::Op);
  {{
{stringise}
  }}
  END_ENUM_STRINGISE();
}}

namespace rdcspv
{{

template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const Id &el)
{{
  return idName(el);
}}

template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcstr &el)
{{
  return "\\"" + el + "\\"";
}}

template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairLiteralIntegerIdRef &el)
{{
  return StringFormat::Fmt("[%u, %s]", el.first, idName(el.second).c_str());
}}

template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdRefLiteralInteger &el)
{{
  return StringFormat::Fmt("[%s, %u]", idName(el.first).c_str(), el.second);
}}

template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdRefIdRef &el)
{{
  return StringFormat::Fmt("[%s, %s]", idName(el.first).c_str(), idName(el.second).c_str());
}}

{tostrs}

void OpDecoder::AddUsedIDs(std::set<Id> &usedids, const ConstIter &it)
{{
  switch(it.opcode())
  {{
{used_ids}
    case Op::Max: break;
  }}
}}

rdcstr OpDecoder::Disassemble(const ConstIter &it, const std::function<rdcstr(Id,Id)> &declName, const std::function<rdcstr(rdcspv::Id)> &idName, const std::function<uint32_t(Id)> &constIntVal)
{{
  rdcstr ret;
  switch(it.opcode())
  {{
{disassemble}
    case Op::Max: break;
  }}
  return ret;
}}
  
OpDecoder::OpDecoder(const ConstIter &it)
{{
  op = it.opcode();
  wordCount = (uint16_t)it.size();
  switch(op)
  {{
{op_decoder}
    case Op::Max: break;
  }}
}}

}}; // namespace rdcspv

'''

op_cpp_fields = {
    'stringise': stringise.rstrip(),
    'op_decoder': op_decoder.rstrip(),
    'used_ids': used_ids.rstrip(),
    'disassemble': disassemble.rstrip(),
    'tostrs': tostrs.rstrip(),
}
cpp.write(_OP_CPP_TEMPLATE.format(**op_cpp_fields))

###############################################################################
##
## GLSL ext inst set (declare enum)
##
###############################################################################

# GLSL.std.450 extended instruction set: declare its enum in the header, then emit
# stringisers for it and for rdcspv::Generator (generator_tostr is assembled earlier in
# this script).
glsl_insts = glsl450['instructions']
decl = ''.join('  {} = {},\n'.format(inst['opname'], inst['opcode']) for inst in glsl_insts)
stringise = ''.join('    STRINGISE_ENUM_CLASS({});\n'.format(inst['opname']) for inst in glsl_insts)

header.write('enum class GLSLstd450 : uint32_t\n{\n' + decl + '\n  Max,\n  Invalid = ~0U,\n};\n\n')


def _stringise_func(type_name, cases):
    """Return the C++ DoStringise specialisation for *type_name* wrapping *cases*."""
    return ('template <>\n'
            'rdcstr DoStringise(const {t} &el)\n'
            '{{\n'
            '  BEGIN_ENUM_STRINGISE({t});\n'
            '  {{\n'
            '{cases}\n'
            '  }}\n'
            '  END_ENUM_STRINGISE();\n'
            '}}\n').format(t = type_name, cases = cases)


cpp.write(_stringise_func('rdcspv::GLSLstd450', stringise.rstrip()) + '\n' +
          _stringise_func('rdcspv::Generator', generator_tostr.rstrip()))

# Close the rdcspv namespace in spirv_gen.h, then declare stringisers for the extra enums
# and for every ValueEnum / BitEnum operand kind from the grammar.
header.write('\n}; // namespace rdcspv\n\n'
             'DECLARE_STRINGISE_TYPE(rdcspv::GLSLstd450);\n'
             'DECLARE_STRINGISE_TYPE(rdcspv::Generator);\n\n')

header.write(''.join('DECLARE_STRINGISE_TYPE(rdcspv::{});\n'.format(kind_desc['kind'])
                     for kind_desc in spirv['operand_kinds']
                     if kind_desc['category'] in ('ValueEnum', 'BitEnum')))

# Close the namespace in spirv_op_helpers.h as well.
ops_header.write('\n}; // namespace rdcspv\n')

# Flush and close all three generated files, in the order they were opened.
for generated in (header, ops_header, cpp):
    generated.close()
