package engine

import (
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"log"
	"slices"
	"strconv"
	"strings"

	_ "github.com/mattn/go-sqlite3"
	engine "git.pablu.de/pablu/sqv-engine/sql"
)

// Manager turns a SQL schema into Tables and, when it holds a database
// connection, loads row data for them. The connection is nil for Managers
// created with NewManagerFromFile; methods that load rows return an error
// in that case.
type Manager struct {
	parser *engine.Parser
	conn   *sql.DB
	tables []Table
}

// NewManagerFromFile returns a Manager that parses the given SQL schema text.
// The Manager has no database connection. Call Start to populate its tables.
func NewManagerFromFile(sqlTxt string) *Manager {
	return &Manager{
		parser: engine.NewParser(strings.NewReader(sqlTxt)),
	}
}

// NewManager opens the SQLite database at path and returns a Manager whose
// parser is fed the CREATE statements of every table in sqlite_schema whose
// name does not start with "sqlite_". Call Start to populate its tables.
func NewManager(path string) (*Manager, error) {
	db, err := sql.Open("sqlite3", path)
	if err != nil {
		return nil, err
	}

	schema, err := readTableSchema(db)
	if err != nil {
		// Don't leak the connection pool; the schema error is the one worth
		// reporting, so Close's error is deliberately dropped.
		_ = db.Close()
		return nil, err
	}

	return &Manager{
		parser: engine.NewParser(strings.NewReader(schema)),
		conn:   db,
	}, nil
}

// readTableSchema returns the `sql` text of the matching tables in
// sqlite_schema, joined with ";\n" and terminated by a final ";".
func readTableSchema(db *sql.DB) (string, error) {
	rows, err := db.Query("SELECT name, sql FROM sqlite_schema WHERE type = 'table' AND name NOT LIKE 'sqlite_%'")
	if err != nil {
		return "", err
	}
	defer rows.Close()

	var stmts []string
	for rows.Next() {
		var name, stmt string
		if err := rows.Scan(&name, &stmt); err != nil {
			return "", err
		}
		stmts = append(stmts, stmt)
	}
	if err := rows.Err(); err != nil {
		return "", err
	}

	return strings.Join(stmts, ";\n") + ";", nil
}

// Start parses statements until the parser reports io.EOF. It converts each
// CREATE TABLE statement into a Table and afterwards resolves foreign-key
// references between the parsed tables. Any other statement kind is
// reported as an error.
func (m *Manager) Start() error {
	var createTableStatements []*engine.CreateTableStatement
	for {
		stmt, err := m.parser.Parse()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return err
		}

		switch v := stmt.(type) {
		case *engine.CreateTableStatement:
			t, err := m.convertCreateTableStatementToTable(v)
			if err != nil {
				return err
			}
			m.tables = append(m.tables, t)
			createTableStatements = append(createTableStatements, v)
		default:
			return fmt.Errorf("unsupported statement type %T", stmt)
		}
	}

	// References are resolved only after every table is known, because a
	// column may refer to a table declared later in the schema.
	for _, cts := range createTableStatements {
		if err := m.references(cts); err != nil {
			return err
		}
	}
	return nil
}

// Refresh discards the current tables and re-runs Start on sqlTxt.
func (m *Manager) Refresh(sqlTxt string) error {
	m.parser = engine.NewParser(strings.NewReader(sqlTxt))
	m.tables = make([]Table, 0)
	return m.Start()
}

// GetTables returns the Manager's tables. The returned slice is the
// Manager's own slice, not a copy.
func (m *Manager) GetTables() []Table {
	return m.tables
}

// GetTable returns a copy of the table named name and whether it was found.
// The copy's slice fields share backing arrays with the Manager's table.
func (m *Manager) GetTable(name string) (Table, bool) {
	index := slices.IndexFunc(m.tables, func(t Table) bool {
		return t.Name == name
	})
	if index < 0 {
		return Table{}, false
	}
	return m.tables[index], true
}

// RunSql parses sqlText as a SELECT statement and executes sqlText verbatim.
// It returns a copy of the selected table whose Columns are the selected
// fields (all columns for "*") and whose Rows hold the query result.
func (m *Manager) RunSql(sqlText string) (Table, error) {
	p := engine.NewParser(strings.NewReader(sqlText))
	stmt, err := p.Parse()
	// The statement is still used when the parser reports io.EOF with it.
	if err != nil && !errors.Is(err, io.EOF) {
		return Table{}, err
	}

	selectStmt, ok := stmt.(*engine.SelectStatement)
	if !ok {
		return Table{}, fmt.Errorf("only SELECT statements are supported, got %T", stmt)
	}

	table, ok := m.GetTable(selectStmt.From)
	if !ok {
		return Table{}, fmt.Errorf("table %q not found", selectStmt.From)
	}

	fields := table.Columns
	if !slices.Contains(selectStmt.Fields, "*") {
		fields = make([]Column, 0, len(selectStmt.Fields))
		for _, columnName := range selectStmt.Fields {
			index := slices.IndexFunc(table.Columns, func(c Column) bool {
				return c.Name == columnName
			})
			if index < 0 {
				return Table{}, fmt.Errorf("column %q not found in table %q", columnName, table.Name)
			}
			fields = append(fields, table.Columns[index])
		}
	}
	table.Columns = fields

	err = m.loadTableRaw(&table, fields, sqlText)
	return table, err
}

// loadTableRaw runs query s with args and replaces table.Rows with the
// result, rendered as strings by anyToStr. fields describes the result
// columns in order; each column's Type and NOT_NULL flag select the Scan
// destination.
func (m *Manager) loadTableRaw(table *Table, fields []Column, s string, args ...any) error {
	if m.conn == nil {
		return errors.New("manager has no database connection")
	}

	rows, err := m.conn.Query(s, args...)
	if err != nil {
		return err
	}
	// Unclosed Rows keep their connection busy, so always release them.
	defer rows.Close()

	table.Rows = make([]Row, 0)
	for rows.Next() {
		cols := make([]any, len(fields))
		for i, column := range fields {
			dest, err := newScanDest(column)
			if err != nil {
				return err
			}
			cols[i] = dest
		}
		if err := rows.Scan(cols...); err != nil {
			return err
		}
		table.Rows = append(table.Rows, Row{Values: anyToStr(cols)})
	}
	return rows.Err()
}

// newScanDest returns a fresh rows.Scan destination for one value of column.
// NOT NULL TEXT, INTEGER and REAL columns scan into plain Go types, and
// nullable ones into the sql.Null* types. BLOBs always use *[]byte, which
// Scan sets to nil for NULL.
func newScanDest(column Column) (any, error) {
	notNull := column.Flags.Has(NOT_NULL)
	switch column.Type {
	case BLOB:
		return new([]byte), nil
	case TEXT:
		if notNull {
			return new(string), nil
		}
		return new(sql.NullString), nil
	case INTEGER:
		if notNull {
			return new(int), nil
		}
		return new(sql.NullInt64), nil
	case REAL:
		if notNull {
			return new(float64), nil
		}
		return new(sql.NullFloat64), nil
	default:
		return nil, fmt.Errorf("unknown type %v for column %q", column.Type, column.Name)
	}
}

// LoadTableMaxRows loads at most maxRows rows of table into table.Rows.
// NOTE(review): table.Name is interpolated unquoted, so names that need
// quoting (spaces, keywords) will fail to query.
func (m *Manager) LoadTableMaxRows(table *Table, maxRows int) error {
	query := fmt.Sprintf("SELECT * FROM %v LIMIT ?", table.Name)
	return m.loadTableRaw(table, table.Columns, query, maxRows)
}

// LoadTable loads every row of table into table.Rows.
// NOTE(review): table.Name is interpolated unquoted, as in LoadTableMaxRows.
func (m *Manager) LoadTable(table *Table) error {
	query := fmt.Sprintf("SELECT * FROM %v", table.Name)
	return m.loadTableRaw(table, table.Columns, query)
}

// anyToStr renders scanned values for display:
//   - floats use two decimal places;
//   - BLOBs are hex-encoded, truncated to their first 512 bytes;
//   - NULL values become "NULL".
//
// It panics on a destination type that newScanDest never produces.
func anyToStr(a []any) []string {
	const maxBlobBytes = 512

	res := make([]string, len(a))
	for i, c := range a {
		switch v := c.(type) {
		case *string:
			res[i] = *v
		case *int:
			res[i] = strconv.Itoa(*v)
		case *float64:
			res[i] = strconv.FormatFloat(*v, 'f', 2, 64)
		case *[]byte:
			if *v == nil {
				// database/sql stores nil into a *[]byte destination for NULL.
				res[i] = "NULL"
			} else {
				buf := *v
				if len(buf) > maxBlobBytes {
					buf = buf[:maxBlobBytes]
				}
				res[i] = hex.EncodeToString(buf)
			}
		case *sql.NullInt64:
			if v.Valid {
				// FormatInt avoids truncating through int on 32-bit platforms.
				res[i] = strconv.FormatInt(v.Int64, 10)
			} else {
				res[i] = "NULL"
			}
		case *sql.NullFloat64:
			if v.Valid {
				res[i] = strconv.FormatFloat(v.Float64, 'f', 2, 64)
			} else {
				res[i] = "NULL"
			}
		case *sql.NullString:
			if v.Valid {
				res[i] = v.String
			} else {
				res[i] = "NULL"
			}
		default:
			panic("THIS SHOULD NEVER HAPPEN, WE GOT SERVED AN UNKNOWN TYPE")
		}
	}
	return res
}

// references resolves the foreign keys declared in cts. For each column with
// a "ref <table>.<column>" extra, it points the matching Column of the
// already-converted table at the referenced Column.
func (m *Manager) references(cts *engine.CreateTableStatement) error {
	table, ok := m.GetTable(cts.TableName)
	if !ok {
		return fmt.Errorf("No table with name found, name: %v", cts.TableName)
	}

	for i, column := range cts.Columns {
		if !extrasToFlags(column.Extra).Has(FOREIGN_KEY) {
			continue
		}

		// FOREIGN_KEY is only set for an extra whose first word is "ref",
		// so this search always succeeds.
		index := slices.IndexFunc(column.Extra, func(c string) bool {
			return strings.HasPrefix(c, "ref")
		})
		refExtra := column.Extra[index]

		parts := strings.Split(refExtra, " ")
		if len(parts) < 2 {
			return fmt.Errorf("malformed reference %q on column %q of table %q", refExtra, column.Name, cts.TableName)
		}
		s := strings.Split(parts[1], ".")
		if len(s) < 2 {
			return fmt.Errorf("malformed reference %q on column %q of table %q", refExtra, column.Name, cts.TableName)
		}
		tableName, columnName := s[0], s[1]

		refTable, ok := m.GetTable(tableName)
		if !ok {
			return fmt.Errorf("Reference table '%v' not found", tableName)
		}
		colIndex := slices.IndexFunc(refTable.Columns, func(c Column) bool {
			return c.Name == columnName
		})
		if colIndex < 0 {
			return fmt.Errorf("reference column %q not found in table %q", columnName, tableName)
		}

		// GetTable returns Table copies, but their Columns slices share the
		// backing arrays stored in m.tables. This write and this pointer
		// therefore target the Manager's own columns.
		table.Columns[i].Reference = &refTable.Columns[colIndex]
	}
	return nil
}

// convertCreateTableStatementToTable builds an empty Table (no rows) from
// cts. It maps REAL, BLOB and TEXT to the same-named column types, and both
// INTEGER and NUMERIC to INTEGER. Unsupported column types and unknown
// extras are reported as errors.
func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStatement) (Table, error) {
	res := Table{
		Name:    cts.TableName,
		Columns: make([]Column, len(cts.Columns)),
		Rows:    make([]Row, 0),
	}

	for i, column := range cts.Columns {
		flags, err := parseColumnExtras(column.Extra)
		if err != nil {
			return Table{}, fmt.Errorf("table %q column %q: %w", cts.TableName, column.Name, err)
		}

		var columnType ColumnType
		switch column.Type {
		case "REAL":
			columnType = REAL
		case "BLOB":
			columnType = BLOB
		case "TEXT":
			columnType = TEXT
		case "INTEGER", "NUMERIC":
			columnType = INTEGER
		default:
			return Table{}, fmt.Errorf("table %q column %q: unsupported column type %q", cts.TableName, column.Name, column.Type)
		}

		res.Columns[i] = Column{
			Type:  columnType,
			Name:  column.Name,
			Flags: flags,
		}
	}
	return res, nil
}

// parseColumnExtras folds the parser's column extras into a ColumnFlag.
// Only the first space-separated word of each extra is inspected; for
// example, "ref users.id" maps to FOREIGN_KEY. Unknown extras are errors.
func parseColumnExtras(extras []string) (ColumnFlag, error) {
	res := NONE
	for _, extra := range extras {
		keyword, _, _ := strings.Cut(extra, " ")
		switch keyword {
		case "PRIMARY_KEY":
			res |= PRIMARY_KEY
		case "ref":
			res |= FOREIGN_KEY
		case "NOT_NULL":
			res |= NOT_NULL
		case "AUTOINCREMENT":
			res |= AUTO_INCREMENT
		default:
			return NONE, fmt.Errorf("unsupported column extra %q", extra)
		}
	}
	return res, nil
}

// extrasToFlags is parseColumnExtras for extras already known to be valid.
// It panics on an unknown extra.
func extrasToFlags(extras []string) ColumnFlag {
	flags, err := parseColumnExtras(extras)
	if err != nil {
		log.Panicf("NOT IMPLEMENTED EXTRA: %v", err)
	}
	return flags
}