package engine import ( "database/sql" "encoding/hex" "errors" "fmt" "io" "log" "slices" "strconv" "strings" _ "github.com/mattn/go-sqlite3" engine "git.pablu.de/pablu/sqv-engine/sql" ) type Manager struct { parser *engine.Parser conn *sql.DB tables []Table } func NewManagerFromFile(sqlTxt string) *Manager { return &Manager{ parser: engine.NewParser(strings.NewReader(sqlTxt)), } } func NewManager(path string) (*Manager, error) { db, err := sql.Open("sqlite3", path) if err != nil { return nil, err } var sqls []string r, err := db.Query("SELECT name, sql FROM sqlite_schema WHERE type = 'table' AND name NOT LIKE 'sqlite_%'") if err != nil { return nil, err } for r.Next() { var name, sql string if err := r.Scan(&name, &sql); err != nil { return nil, err } sqls = append(sqls, sql) } schema := strings.Join(sqls, ";") schema += ";" // fmt.Println(schema) return &Manager{ parser: engine.NewParser(strings.NewReader(schema)), conn: db, }, nil } func (m *Manager) RunSql(sqlText string) (Table, error) { p := engine.NewParser(strings.NewReader(sqlText)) stmt, err := p.Parse() if err != nil && !errors.Is(err, io.EOF) { return Table{}, err } selectStmt, ok := stmt.(*engine.SelectStatement) if !ok { panic("HELP ITS NOT A SELECT STATMET") } table, ok := m.GetTable(selectStmt.From) if !ok { panic("HELP TABLE NOT FOUND") } rows, err := m.conn.Query(sqlText) if err != nil { return Table{}, err } table.Rows = make([]Row, 0) if slices.Contains(selectStmt.Fields, "*") { for rows.Next() { cols := make([]any, len(table.Columns)) for i, column := range table.Columns { switch column.Type { case BLOB: cols[i] = new([]byte) case TEXT: cols[i] = new(string) case INTEGER: cols[i] = new(int) case REAL: cols[i] = new(float64) default: panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE") } } err = rows.Scan(cols...) if err != nil { return Table{}, err } table.Rows = append(table.Rows, Row{ Values: anyToStr(cols), }) } return table, nil } else { columns := make([]Column, len(selectStmt.Fields)) firstTime := true for rows.Next() { cols := make([]any, len(selectStmt.Fields)) for i, s := range selectStmt.Fields { for _, column := range table.Columns { if column.Name != s { continue } if firstTime { columns[i] = column } switch column.Type { case BLOB: cols[i] = new([]byte) case TEXT: cols[i] = new(string) case INTEGER: cols[i] = new(int) case REAL: cols[i] = new(float64) default: panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE") } } } firstTime = false err = rows.Scan(cols...) if err != nil { return Table{}, err } table.Rows = append(table.Rows, Row{ Values: anyToStr(cols), }) } nTable := Table{ Name: selectStmt.From, Columns: columns, Rows: table.Rows, } return nTable, nil } } func (m *Manager) Start() error { for { stmt, err := m.parser.Parse() if err != nil && errors.Is(err, io.EOF) { fmt.Println("Finished parsing") break } else if err != nil { return err } switch v := stmt.(type) { case *engine.CreateTableStatement: t, err := m.convertCreateTableStatementToTable(v) if err != nil { return err } m.tables = append(m.tables, t) default: panic("NOT IMPLEMENTED") } } return nil } func (m *Manager) Refresh(sqlTxt string) error { m.parser = engine.NewParser(strings.NewReader(sqlTxt)) m.tables = make([]Table, 0) return m.Start() } func (m *Manager) GetTables() []Table { return m.tables } func (m *Manager) GetTable(name string) (Table, bool) { index := slices.IndexFunc(m.tables, func(t Table) bool { return t.Name == name }) if index < 0 { return Table{}, false } table := m.tables[index] return table, true } func (m *Manager) LoadTable(table *Table) error { rows, err := m.conn.Query(fmt.Sprintf("SELECT * FROM %v", table.Name)) if err != nil { return err } table.Rows = make([]Row, 0) for rows.Next() { cols := make([]any, len(table.Columns)) for i, column := range table.Columns { switch column.Type { case BLOB: cols[i] = new([]byte) case TEXT: cols[i] = new(string) case INTEGER: cols[i] = new(int) case REAL: cols[i] = new(float64) default: panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE") } } err = rows.Scan(cols...) if err != nil { return err } table.Rows = append(table.Rows, Row{ Values: anyToStr(cols), }) } return nil } func anyToStr(a []any) []string { res := make([]string, len(a)) for i, c := range a { switch v := c.(type) { case *string: res[i] = *v case *int: res[i] = strconv.Itoa(*v) case *float64: res[i] = strconv.FormatFloat(*v, 'f', 2, 64) case *[]byte: res[i] = hex.EncodeToString(*v) default: panic("THIS SHOULD NEVER HAPPEN, WE GOT SERVED AN UNKNOWN TYPE") } } return res } func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStatement) (Table, error) { res := Table{ Name: cts.TableName, Columns: make([]Column, len(cts.Columns)), Rows: make([]Row, 0), } for i, column := range cts.Columns { flags := extrasToFlags(column.Extra) var ref *Column = nil if flags.Has(FOREIGN_KEY) { index := slices.IndexFunc(column.Extra, func(c string) bool { return strings.HasPrefix(c, "ref") }) refExtra := column.Extra[index] refStr := strings.Split(refExtra, " ")[1] s := strings.Split(refStr, ".") tableName := s[0] columnName := s[1] refTable, ok := m.GetTable(tableName) if !ok { fmt.Println(m.tables) return Table{}, fmt.Errorf("Reference table '%v' not found", tableName) } colIndex := slices.IndexFunc(refTable.Columns, func(c Column) bool { return c.Name == columnName }) ref = &refTable.Columns[colIndex] } var columnType ColumnType switch column.Type { case "REAL": columnType = REAL case "BLOB": columnType = BLOB case "TEXT": columnType = TEXT case "INTEGER": columnType = INTEGER default: panic("This shouldnt happen") } res.Columns[i] = Column{ Type: columnType, Name: column.Name, Reference: ref, Flags: flags, } } return res, nil } func extrasToFlags(extras []string) ColumnFlag { res := NONE for _, extra := range extras { // This is not good switch strings.Split(extra, " ")[0] { case "PRIMARY_KEY": res |= PRIMARY_KEY case "ref": res |= FOREIGN_KEY case "NOT_NULL": res |= NOT_NULL default: log.Panicf("NOT IMPLEMENTED EXTRA: %v", extra) } } return res }