package main

import (
	"bytes"
	"fmt"
	"log"
	"os"
	"sort"
	"strconv"
	"strings"
	"unicode"

	"github.com/goccy/go-yaml"
	"github.com/lestrrat-go/codegen"
)

// AlgYAML is the root of the YAML definition file read by _main. Its fields
// (and those of the types it contains) are exported so that the YAML decoder
// can populate them.
type AlgYAML struct {
	// Algorithms lists the types to generate; _main produces one source
	// file and one test file per entry.
	Algorithms []Algorithm `yaml:"algorithms"`
}

// Algorithm describes one generated type together with the values it can
// take.
type Algorithm struct {
	// Name is used as the Go type name in the generated code and as the
	// suffix of the generated identifiers (New<Name>, Lookup<Name>,
	// Register<Name>, ...).
	Name      string    `yaml:"name"`
	// Comment is emitted as the comment line directly above the generated
	// type declaration.
	Comment   string    `yaml:"comment"`
	// Filename is the path the generated source is written to. The test
	// file path is derived from it by substituting "_gen_test.go" for
	// "_gen.go".
	Filename  string    `yaml:"filename"`
	// Elements are the individual values defined for this type.
	Elements  []Element `yaml:"elements"`
	// Symmetric, when true, makes the generated type carry an isSymmetric
	// field and an IsSymmetric method, and makes New<Name> accept
	// New<Name>Option values instead of NewAlgorithmOption values.
	Symmetric bool      `yaml:"symmetric"`
}

// Element describes a single value of an Algorithm and the accessor function
// generated for it.
type Element struct {
	// Name is the name of the generated accessor function that returns
	// this value.
	Name             string `yaml:"name"`
	// Value is the string form of the element. It is emitted as a quoted
	// literal in the generated source (unless TokenReference is set) and
	// in the generated tests.
	Value            string `yaml:"value"`
	// TokenReference, when non-empty, is emitted verbatim in place of the
	// quoted Value in the generated source, and causes the jwx
	// internal/tokens package to be added to the import list.
	TokenReference   string `yaml:"token_reference"`
	// ReturnvalComment, when non-empty, replaces Value in the accessor's
	// doc comment. It is required when Value is empty.
	ReturnvalComment string `yaml:"returnval_comment"`
	// Comment is extra text appended to the accessor's doc comment line.
	Comment          string `yaml:"comment"`
	// Invalid marks an element that is not registered in init. Its
	// accessor returns a dedicated package-level variable instead of
	// looking the value up in the registry.
	Invalid          bool   `yaml:"invalid"`
	// Sym makes the element be constructed with WithIsSymmetric(true);
	// the generated tests expect IsSymmetric to report this value.
	Sym              bool   `yaml:"sym"`
	// Deprecated makes the element be constructed with
	// WithDeprecated(true).
	Deprecated       bool   `yaml:"deprecated"`
}

// main runs the generator. Any error is logged and the process exits with
// status 1.
func main() {
	err := _main()
	if err == nil {
		return
	}

	log.Printf("%s", err)
	os.Exit(1)
}

// _main loads the algorithm definitions from a YAML file and generates one
// source file and one test file for each of them.
//
// The definition file is the first command line argument when one is given
// and "objects.yml" otherwise. Definitions are processed in name order, and
// the elements of each definition are sorted by name before generation.
func _main() error {
	path := "objects.yml"
	if len(os.Args) >= 2 {
		path = os.Args[1]
	}

	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return fmt.Errorf("failed to read %s: %w", path, readErr)
	}

	var doc AlgYAML
	if decodeErr := yaml.Unmarshal(raw, &doc); decodeErr != nil {
		return fmt.Errorf("failed to unmarshal %s: %w", path, decodeErr)
	}

	defs := doc.Algorithms
	sort.Slice(defs, func(a, b int) bool { return defs[a].Name < defs[b].Name })

	for idx := range defs {
		def := defs[idx]

		// elems shares its backing array with def.Elements, so sorting it
		// orders the elements seen by the generators.
		elems := def.Elements
		sort.Slice(elems, func(a, b int) bool { return elems[a].Name < elems[b].Name })

		if genErr := Generate(def); genErr != nil {
			return fmt.Errorf("failed to generate file: %w", genErr)
		}
		if genErr := GenerateTest(def); genErr != nil {
			return fmt.Errorf("failed to generate test file: %w", genErr)
		}
	}
	return nil
}

// Generate writes the Go source file defining the jwa type described by t to
// t.Filename.
//
// The emitted file contains, in order: the package-level registry variables,
// an init function that registers every non-invalid element, one accessor
// function per element, the type declaration with its String, IsDeprecated
// and (for symmetric types) IsSymmetric methods, the Empty/New/Lookup/
// Register/Unregister functions, the sorted listing function, and the JSON
// (un)marshaling methods.
//
// An error is returned when t.Filename is empty, when an element has an
// empty name, when an element has neither a value nor a returnval_comment,
// or when the file cannot be formatted or written. If formatting fails, the
// unformatted source is printed to stderr to help locate the problem.
func Generate(t Algorithm) error {
	if t.Filename == "" {
		return fmt.Errorf("filename is empty for type %q", t.Name)
	}

	// Validate every element before emitting anything. An empty name would
	// otherwise panic when its first byte is lower-cased below, and an
	// element without any description cannot get a doc comment.
	invalids := 0
	needsTokens := false
	for _, e := range t.Elements {
		if e.Name == "" {
			return fmt.Errorf("element with value %q has an empty name in type %q", e.Value, t.Name)
		}
		if e.Value == "" && e.ReturnvalComment == "" {
			return fmt.Errorf("missing returnval_comment for %s (required if value is empty)", e.Name)
		}
		if e.Invalid {
			invalids++
		}
		if e.TokenReference != "" {
			needsTokens = true
		}
	}

	// valueRef returns the Go expression that denotes an element's value in
	// the generated source: the token reference when there is one, the
	// quoted literal otherwise.
	valueRef := func(value, tokenReference string) string {
		if tokenReference != "" {
			return tokenReference
		}
		return fmt.Sprintf("%q", value)
	}

	var buf bytes.Buffer
	o := codegen.NewOutput(&buf)

	o.R("// Code generated by tools/cmd/genjwa/main.go. DO NOT EDIT.")
	o.LL("package jwa")

	o.LL("import (")
	pkgs := []string{
		"encoding/json",
		"fmt",
		"sort",
		"sync",
	}
	if needsTokens {
		pkgs = append(pkgs, "github.com/lestrrat-go/jwx/v3/internal/tokens")
	}
	for _, pkg := range pkgs {
		o.L("%s", strconv.Quote(pkg))
	}
	o.L(")")

	o.LL("var muAll%s sync.RWMutex", t.Name)
	o.L("var all%[1]s = map[string]%[1]s{}", t.Name)
	o.L("var muList%s sync.RWMutex", t.Name)
	o.L("var list%s []%s", t.Name, t.Name)
	// NOTE(review): nothing emitted by this generator ever inserts into
	// builtin<Name>, so the guard in Unregister<Name> only takes effect if
	// another file of package jwa populates the map. Confirm whether that
	// is intended.
	o.L("var builtin%s = map[string]struct{}{}", t.Name)

	o.LL("func init() {")
	o.L("// builtin values for %s", t.Name)
	// Invalid elements are not registered, so allocate just enough space
	// for the remaining ones.
	o.L("algorithms := make([]%s, %d)", t.Name, len(t.Elements)-invalids)
	ecount := 0
	for _, e := range t.Elements {
		if e.Invalid {
			continue
		}
		o.L("algorithms[%d] = New%s(%s", ecount, t.Name, valueRef(e.Value, e.TokenReference))
		ecount++

		// The options continue the constructor call on the same line.
		if e.Deprecated {
			o.R(", WithDeprecated(true)")
		}
		if e.Sym {
			o.R(", WithIsSymmetric(true)")
		}
		o.R(")")
	}

	o.LL("Register%s(algorithms...)", t.Name)
	o.L("}") // end init

	// Accessors for builtin algorithms
	for _, e := range t.Elements {
		// Invalid elements live in an unexported variable named after the
		// accessor with its first byte lower-cased.
		varName := fmt.Sprintf("%c%s", unicode.ToLower(rune(e.Name[0])), e.Name[1:])
		if e.Invalid {
			o.L("var %s = New%s(%q)", varName, t.Name, e.Value)
		}

		// returnval_comment takes precedence over the raw value in the doc
		// comment; validation above guarantees one of them is set.
		described := e.Value
		if e.ReturnvalComment != "" {
			described = e.ReturnvalComment
		}
		o.LL("// %s returns an object representing %s.", e.Name, described)
		if e.Comment != "" {
			o.R(" %s", e.Comment)
		}
		o.L("func %s() %s {", e.Name, t.Name)
		if e.Invalid {
			o.L("return %s", varName)
		} else {
			o.L("return lookupBuiltin%s(%s)", t.Name, valueRef(e.Value, e.TokenReference))
		}
		o.L("}")
	}

	o.LL("func lookupBuiltin%s(name string) %s {", t.Name, t.Name)
	o.L("muAll%s.RLock()", t.Name)
	o.L("v, ok := all%s[name]", t.Name)
	o.L("muAll%s.RUnlock()", t.Name)
	o.L("if !ok {")
	o.L("panic(fmt.Sprintf(`jwa: %s %%q not registered`, name))", t.Name)
	o.L("}")
	o.L("return v")
	o.L("}")

	o.LL("// %s", t.Comment)
	o.L("type %s struct {", t.Name)
	o.L("name string")
	o.L("deprecated bool")
	if t.Symmetric {
		o.L("isSymmetric bool")
	}
	o.L("}")

	o.LL("func (s %s) String() string {", t.Name)
	o.L("return s.name")
	o.L("}")

	o.LL("// IsDeprecated returns true if the %s object is deprecated.", t.Name)
	o.L("func (s %s) IsDeprecated() bool {", t.Name)
	o.L("return s.deprecated")
	o.L("}")

	if t.Symmetric {
		o.LL("// IsSymmetric returns true if the %s object is symmetric. Symmetric algorithms use the same key for both encryption and decryption.", t.Name)
		o.L("func (s %s) IsSymmetric() bool {", t.Name)
		o.L("return s.isSymmetric")
		o.L("}")
	}

	o.LL("// Empty%[1]s returns an empty %[1]s object, used as a zero value.", t.Name)
	o.L("func Empty%s() %s {", t.Name, t.Name)
	o.L("return %s{}", t.Name)
	o.L("}")

	o.LL("// New%[1]s creates a new %[1]s object with the given name.", t.Name)

	// The signature is assembled piecewise because the option type depends
	// on whether the type is symmetric.
	o.L("func New%[1]s(name string", t.Name)
	if t.Symmetric {
		o.R(", options ...New%[1]sOption", t.Name)
	} else {
		o.R(", options ...NewAlgorithmOption")
	}
	o.R(") %[1]s {", t.Name)
	o.L("var deprecated bool")
	if t.Symmetric {
		o.L("var isSymmetric bool")
	}
	o.L("for _, option := range options {")
	o.L("switch option.Ident() {")
	if t.Symmetric {
		o.L("case identIsSymmetric{}:")
		o.L("if err := option.Value(&isSymmetric); err != nil {")
		o.L("panic(\"jwa.New%s: WithIsSymmetric option must be a boolean\")", t.Name)
		o.L("}")
	}
	o.L("case identDeprecated{}:")
	o.L("if err := option.Value(&deprecated); err != nil {")
	o.L("panic(\"jwa.New%s: WithDeprecated option must be a boolean\")", t.Name)
	o.L("}")
	o.L("}")
	o.L("}")
	o.L("return %s{name: name, deprecated: deprecated", t.Name)
	if t.Symmetric {
		o.R(", isSymmetric: isSymmetric")
	}
	o.R("}")
	o.L("}")

	o.LL("// Lookup%[1]s returns the %[1]s object for the given name.", t.Name)
	o.L("func Lookup%[1]s(name string) (%[1]s, bool) {", t.Name)
	o.L("muAll%[1]s.RLock()", t.Name)
	o.L("v, ok := all%[1]s[name]", t.Name)
	o.L("muAll%[1]s.RUnlock()", t.Name)
	o.L("return v, ok")
	o.L("}")

	o.LL("// Register%[1]s registers a new %[1]s. The signature value must be immutable", t.Name)
	o.L("// and safe to be used by multiple goroutines, as it is going to be shared with all other users of this library.")
	o.L("func Register%[1]s(algorithms ...%[1]s) {", t.Name)
	o.L("muAll%[1]s.Lock()", t.Name)
	o.L("for _, alg := range algorithms {")
	o.L("all%[1]s[alg.String()] = alg", t.Name)
	o.L("}")
	o.L("muAll%[1]s.Unlock()", t.Name)
	o.L("rebuild%[1]s()", t.Name)
	o.L("}")

	o.LL("// Unregister%[1]s unregisters a %[1]s from its known database.", t.Name)
	o.L("// Non-existent entries, as well as built-in algorithms will silently be ignored.")
	o.L("func Unregister%[1]s(algorithms ...%[1]s) {", t.Name)
	o.L("muAll%[1]s.Lock()", t.Name)
	o.L("for _, alg := range algorithms {")
	o.L("if _, ok := builtin%[1]s[alg.String()]; ok {", t.Name)
	o.L("continue")
	o.L("}")
	o.L("delete(all%[1]s, alg.String())", t.Name)
	o.L("}")
	o.L("muAll%[1]s.Unlock()", t.Name)
	o.L("rebuild%[1]s()", t.Name)
	o.L("}")

	// The read lock must be taken before len(all<Name>) is evaluated:
	// reading the map's length while Register/Unregister write to it from
	// another goroutine is a data race.
	o.LL("func rebuild%[1]s() {", t.Name)
	o.L("muAll%[1]s.RLock()", t.Name)
	o.L("list := make([]%[1]s, 0, len(all%[1]s))", t.Name)
	o.L("for _, v := range all%[1]s {", t.Name)
	o.L("list = append(list, v)")
	o.L("}")
	o.L("muAll%[1]s.RUnlock()", t.Name)
	o.L("sort.Slice(list, func(i, j int) bool {")
	o.L("return list[i].String() < list[j].String()")
	o.L("})")
	o.L("muList%[1]s.Lock()", t.Name)
	o.L("list%[1]s = list", t.Name)
	o.L("muList%[1]s.Unlock()", t.Name)
	o.L("}")

	o.LL("// %[1]ss returns a list of all available values for %[1]s.", t.Name)
	o.L("func %[1]ss() []%[1]s {", t.Name)
	o.L("muList%[1]s.RLock()", t.Name)
	o.L("defer muList%[1]s.RUnlock()", t.Name)
	o.L("return list%[1]s", t.Name)
	o.L("}")

	o.LL("// MarshalJSON serializes the %[1]s object to a JSON string.", t.Name)
	o.L("func (s %[1]s) MarshalJSON() ([]byte, error) {", t.Name)
	o.L("return json.Marshal(s.String())")
	o.L("}")

	o.LL("// UnmarshalJSON deserializes the JSON string to a %[1]s object.", t.Name)
	o.L("func (s *%[1]s) UnmarshalJSON(data []byte) error {", t.Name)
	o.L("var name string")
	o.L("if err := json.Unmarshal(data, &name); err != nil {")
	o.L("return fmt.Errorf(`failed to unmarshal %[1]s: %%w`, err)", t.Name)
	o.L("}")
	o.L("v, ok := Lookup%[1]s(name)", t.Name)
	o.L("if !ok {")
	o.L("return fmt.Errorf(`unknown %[1]s: %%q`, name)", t.Name)
	o.L("}")
	o.L("*s = v")
	o.L("return nil")
	o.L("}")

	if err := o.WriteFile(t.Filename, codegen.WithFormatCode(true)); err != nil {
		// Show the raw source when it could not be formatted, so that the
		// offending construct can be located.
		if cfe, ok := err.(codegen.CodeFormatError); ok {
			fmt.Fprint(os.Stderr, cfe.Source())
		}
		return fmt.Errorf(`failed to write to %s: %w`, t.Filename, err)
	}
	return nil
}

// GenerateTest writes the Go test file that accompanies the source produced
// by Generate for t.
//
// The test file path is t.Filename with its "_gen.go" suffix replaced by
// "_gen_test.go". A filename without that suffix is rejected, because the
// derived path would otherwise be identical to the generated source file and
// overwrite it.
//
// Two test functions are emitted: Test<Name>, which checks lookup, JSON
// unmarshaling and stringification of every non-invalid element as well as
// the listing function, and Test<Name>CustomAlgorithm, which registers and
// unregisters a custom value.
//
// An error is returned when the filename is unsuitable or when the file
// cannot be formatted or written. If formatting fails, the unformatted
// source is printed to stderr.
func GenerateTest(t Algorithm) error {
	const (
		genSuffix     = "_gen.go"
		genTestSuffix = "_gen_test.go"
	)
	if !strings.HasSuffix(t.Filename, genSuffix) {
		return fmt.Errorf("filename %q for type %q does not end in %q", t.Filename, t.Name, genSuffix)
	}
	filename := strings.TrimSuffix(t.Filename, genSuffix) + genTestSuffix

	// Only elements that are registered can be looked up and unmarshaled.
	valids := make([]Element, 0, len(t.Elements))
	for _, e := range t.Elements {
		if !e.Invalid {
			valids = append(valids, e)
		}
	}

	var buf bytes.Buffer
	o := codegen.NewOutput(&buf)
	// The trailing period is required for the line to match the pattern
	// that Go tooling uses to recognize generated files.
	o.R("// Code generated by tools/cmd/genjwa/main.go. DO NOT EDIT.")
	o.LL("package jwa_test")

	o.L("import (")
	pkgs := []string{
		"encoding/json",
		"strconv",
		"testing",
		"github.com/lestrrat-go/jwx/v3/jwa",
		"github.com/stretchr/testify/require",
	}
	for _, pkg := range pkgs {
		o.L("%s", strconv.Quote(pkg))
	}
	o.L(")")

	o.LL("func Test%s(t *testing.T) {", t.Name)
	o.L("t.Parallel()")
	for _, e := range valids {
		o.L("t.Run(`Lookup the object`, func(t *testing.T) {")
		o.L("t.Parallel()")
		o.L("v, ok := jwa.Lookup%s(%q)", t.Name, e.Value)
		o.L("require.True(t, ok, `Lookup should succeed`)")
		o.L("require.Equal(t, jwa.%s(), v, `Lookup value should be equal to constant`)", e.Name)
		o.L("})")

		o.L("t.Run(`Unmarshal the string %s`, func(t *testing.T) {", e.Value)
		o.L("t.Parallel()")
		o.L("var dst jwa.%s", t.Name)
		o.L("require.NoError(t, json.Unmarshal([]byte(strconv.Quote(%q)), &dst), `UnmarshalJSON is successful`)", e.Value)
		o.L("require.Equal(t, jwa.%s(), dst, `unmarshaled value should be equal to constant`)", e.Name)
		o.L("})")

		o.L("t.Run(`stringification for %s`, func(t *testing.T) {", e.Value)
		o.L("t.Parallel()")
		o.L("require.Equal(t, %#v, jwa.%s().String(), `stringified value matches`)", e.Value, e.Name)
		o.L("})")
	}

	// The value is quoted so that it is a well-formed JSON string: the
	// failure must come from the unknown name, not from a syntax error.
	o.L("t.Run(`Unmarshal should fail for invalid value (totally made up) string value`, func(t *testing.T) {")
	o.L("t.Parallel()")
	o.L("var dst jwa.%s", t.Name)
	o.L("require.Error(t, json.Unmarshal([]byte(strconv.Quote(`totallyInvalidValue`)), &dst), `Unmarshal should fail`)")
	o.L("})")

	if t.Symmetric {
		o.L("t.Run(`check symmetric values`, func(t *testing.T) {")
		o.L("t.Parallel()")
		for _, e := range t.Elements {
			o.L("t.Run(`%s`, func(t *testing.T) {", e.Name)
			if e.Sym {
				o.L("require.True")
			} else {
				o.L("require.False")
			}
			// The argument list continues the assertion started above.
			o.R("(t, jwa.%[1]s().IsSymmetric(), `jwa.%[1]s returns expected value`)", e.Name)
			o.L("})")
		}
		o.L("})")
	}

	o.L("t.Run(`check list of elements`, func(t *testing.T) {")
	o.L("t.Parallel()")
	o.L("var expected = map[jwa.%s]struct{} {", t.Name)
	for _, e := range valids {
		o.L("jwa.%s(): {},", e.Name)
	}
	o.L("}")
	o.L("for _, v := range jwa.%ss() {", t.Name)
	if t.Name == "EllipticCurveAlgorithm" {
		o.L("// There is no good way to detect from a test if es256k (secp256k1)")
		o.L("// is supported, so just allow it")
		o.L("if v.String() == `secp256k1` {")
		o.L("continue")
		o.L("}")
	}
	o.L("_, ok := expected[v]")
	o.L("require.True(t, ok, `%%q should be in the list for %s`, v)", t.Name)
	o.L("delete(expected, v)")
	o.L("}")
	o.L("require.Len(t, expected, 0)")
	o.L("})")
	o.L("}")

	o.LL("// Note: this test can NOT be run in parallel as it uses options with global effect.")
	o.L("func Test%sCustomAlgorithm(t *testing.T) {", t.Name)
	o.L("// These subtests can NOT be run in parallel as options with global effect change.")
	o.L("const customAlgorithmValue = `custom-algorithm`")
	if t.Symmetric {
		o.L("for _, symmetric := range []bool{true, false} {")
	}
	o.L(`customAlgorithm := jwa.New%[1]s(customAlgorithmValue`, t.Name)
	if t.Symmetric {
		o.R(`, jwa.WithIsSymmetric(symmetric)`)
	}
	o.R(`)`)
	o.L("// Unregister the custom algorithm, in case tests fail.")
	o.L("t.Cleanup(func() {")
	o.L("jwa.Unregister%[1]s(customAlgorithm)", t.Name)
	o.L("})")
	o.L("t.Run(`with custom algorithm registered`, func(t *testing.T) {")
	o.L("jwa.Register%[1]s(customAlgorithm)", t.Name)
	o.L("t.Run(`Lookup the object`, func(t *testing.T) {")
	o.L("t.Parallel()")
	o.L("v, ok := jwa.Lookup%[1]s(customAlgorithmValue)", t.Name)
	o.L("require.True(t, ok, `Lookup should succeed`)")
	o.L("require.Equal(t, customAlgorithm, v, `Lookup value should be equal to constant`)")
	o.L("})")
	o.L("t.Run(`Unmarshal custom algorithm`, func(t *testing.T) {")
	o.L("t.Parallel()")
	o.L("var dst jwa.%[1]s", t.Name)
	o.L("require.NoError(t, json.Unmarshal([]byte(strconv.Quote(customAlgorithmValue)), &dst), `Unmarshal is successful`)")
	o.L("require.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`)")
	o.L("})")
	if t.Symmetric {
		o.L("t.Run(`check symmetric`, func(t *testing.T) {")
		o.L("t.Parallel()")
		o.L("require.Equal(t, symmetric, customAlgorithm.IsSymmetric(), `custom algorithm's symmetric attribute should match`)")
		o.L("})")
	}
	o.L("})")
	o.L("t.Run(`with custom algorithm deregistered`, func(t *testing.T) {")
	o.L("jwa.Unregister%[1]s(customAlgorithm)", t.Name)
	o.L("t.Run(`Lookup the object`, func(t *testing.T) {")
	o.L("t.Parallel()")
	o.L("_, ok := jwa.Lookup%[1]s(customAlgorithmValue)", t.Name)
	o.L("require.False(t, ok, `Lookup should fail`)")
	o.L("})")
	// As above, quote the value so that the error can only originate from
	// the name no longer being registered.
	o.L("t.Run(`Unmarshal custom algorithm`, func(t *testing.T) {")
	o.L("t.Parallel()")
	o.L("var dst jwa.%[1]s", t.Name)
	o.L("require.Error(t, json.Unmarshal([]byte(strconv.Quote(customAlgorithmValue)), &dst), `Unmarshal should fail`)")
	o.L("})")
	o.L("})")
	if t.Symmetric {
		o.L("}") // ending the for _, symmetric := range loop
	}
	o.L("}")

	if err := o.WriteFile(filename, codegen.WithFormatCode(true)); err != nil {
		// Show the raw source when it could not be formatted, so that the
		// offending construct can be located.
		if cfe, ok := err.(codegen.CodeFormatError); ok {
			fmt.Fprint(os.Stderr, cfe.Source())
		}
		return fmt.Errorf(`failed to write to %s: %w`, filename, err)
	}
	return nil
}
