401 lines
8.2 KiB
Go
401 lines
8.2 KiB
Go
package engine
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
|
|
engine "git.pablu.de/pablu/sqv-engine/sql"
|
|
)
|
|
|
|
type Manager struct {
|
|
parser *engine.Parser
|
|
|
|
conn *sql.DB
|
|
tables []Table
|
|
}
|
|
|
|
func NewManagerFromFile(sqlTxt string) *Manager {
|
|
return &Manager{
|
|
parser: engine.NewParser(strings.NewReader(sqlTxt)),
|
|
}
|
|
}
|
|
|
|
func NewManager(path string) (*Manager, error) {
|
|
db, err := sql.Open("sqlite3", path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var sqls []string
|
|
r, err := db.Query("SELECT name, sql FROM sqlite_schema WHERE type = 'table' AND name NOT LIKE 'sqlite_%'")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for r.Next() {
|
|
var name, sql string
|
|
|
|
if err := r.Scan(&name, &sql); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sqls = append(sqls, sql)
|
|
}
|
|
|
|
schema := strings.Join(sqls, ";\n")
|
|
schema += ";"
|
|
|
|
return &Manager{
|
|
parser: engine.NewParser(strings.NewReader(schema)),
|
|
conn: db,
|
|
}, nil
|
|
}
|
|
|
|
func (m *Manager) Start() error {
|
|
createTableStatements := make([]*engine.CreateTableStatement, 0)
|
|
for {
|
|
stmt, err := m.parser.Parse()
|
|
if err != nil && errors.Is(err, io.EOF) {
|
|
break
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch v := stmt.(type) {
|
|
case *engine.CreateTableStatement:
|
|
t, err := m.convertCreateTableStatementToTable(v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m.tables = append(m.tables, t)
|
|
createTableStatements = append(createTableStatements, v)
|
|
|
|
default:
|
|
panic("NOT IMPLEMENTED")
|
|
}
|
|
}
|
|
|
|
// Rethink how to do this cleanly
|
|
for _, cts := range createTableStatements {
|
|
err := m.references(cts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) Refresh(sqlTxt string) error {
|
|
m.parser = engine.NewParser(strings.NewReader(sqlTxt))
|
|
m.tables = make([]Table, 0)
|
|
|
|
return m.Start()
|
|
}
|
|
|
|
func (m *Manager) GetTables() []Table {
|
|
return m.tables
|
|
}
|
|
|
|
func (m *Manager) GetTable(name string) (Table, bool) {
|
|
index := slices.IndexFunc(m.tables, func(t Table) bool {
|
|
return t.Name == name
|
|
})
|
|
|
|
if index < 0 {
|
|
return Table{}, false
|
|
}
|
|
table := m.tables[index]
|
|
|
|
return table, true
|
|
}
|
|
|
|
func (m *Manager) RunSql(sqlText string) (Table, error) {
|
|
p := engine.NewParser(strings.NewReader(sqlText))
|
|
stmt, err := p.Parse()
|
|
if err != nil && !errors.Is(err, io.EOF) {
|
|
return Table{}, err
|
|
}
|
|
|
|
switch v := stmt.(type) {
|
|
case *engine.SelectStatement:
|
|
return m.tableFromSelectStatement(sqlText, v)
|
|
|
|
case *engine.InsertStatement:
|
|
if !slices.ContainsFunc(m.tables, func(t Table) bool {
|
|
return v.Table == t.Name
|
|
}) {
|
|
return Table{}, fmt.Errorf("Table not found")
|
|
}
|
|
|
|
res, err := m.conn.Exec(sqlText)
|
|
if err != nil {
|
|
return Table{}, err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return Table{}, err
|
|
}
|
|
return Table{}, fmt.Errorf("Rows affected: %v", affected)
|
|
|
|
case *engine.DeleteStatement:
|
|
if !slices.ContainsFunc(m.tables, func(t Table) bool {
|
|
return v.Table == t.Name
|
|
}) {
|
|
return Table{}, fmt.Errorf("Table not found")
|
|
}
|
|
|
|
res, err := m.conn.Exec(sqlText)
|
|
if err != nil {
|
|
return Table{}, err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return Table{}, err
|
|
}
|
|
return Table{}, fmt.Errorf("Rows affected: %v", affected)
|
|
|
|
default:
|
|
return Table{}, fmt.Errorf("Input statement is not of correct Syntax, select statement")
|
|
}
|
|
}
|
|
|
|
func (m *Manager) tableFromSelectStatement(sqlText string, stmt *engine.SelectStatement) (Table, error) {
|
|
table, ok := m.GetTable(stmt.From)
|
|
if !ok {
|
|
return Table{}, fmt.Errorf("Selected Table does not exist, have you perhaps misstyped the table Name?")
|
|
}
|
|
|
|
fields := make([]Column, 0)
|
|
if slices.Contains(stmt.Fields, "*") {
|
|
fields = table.Columns
|
|
} else {
|
|
for _, columnName := range stmt.Fields {
|
|
index := slices.IndexFunc(table.Columns, func(c Column) bool {
|
|
if c.Name == columnName {
|
|
return true
|
|
}
|
|
return false
|
|
})
|
|
|
|
fields = append(fields, table.Columns[index])
|
|
}
|
|
}
|
|
table.Columns = fields
|
|
|
|
err := m.loadTableRaw(&table, fields, sqlText)
|
|
if err != nil {
|
|
return Table{}, err
|
|
}
|
|
|
|
return table, nil
|
|
}
|
|
|
|
func (m *Manager) loadTableRaw(table *Table, fields []Column, s string, args ...any) error {
|
|
rows, err := m.conn.Query(s, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
table.Rows = make([]Row, 0)
|
|
|
|
for rows.Next() {
|
|
cols := make([]any, len(fields))
|
|
for i, column := range fields {
|
|
if column.Flags.Has(NOT_NULL) {
|
|
switch column.Type {
|
|
case BLOB:
|
|
cols[i] = new([]byte)
|
|
case TEXT:
|
|
cols[i] = new(string)
|
|
case INTEGER:
|
|
cols[i] = new(int)
|
|
case REAL:
|
|
cols[i] = new(float64)
|
|
default:
|
|
panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE")
|
|
}
|
|
} else {
|
|
switch column.Type {
|
|
case BLOB:
|
|
cols[i] = new([]byte)
|
|
case TEXT:
|
|
cols[i] = new(sql.NullString)
|
|
case INTEGER:
|
|
cols[i] = new(sql.NullInt64)
|
|
case REAL:
|
|
cols[i] = new(sql.NullFloat64)
|
|
default:
|
|
panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE")
|
|
}
|
|
}
|
|
}
|
|
|
|
err = rows.Scan(cols...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
table.Rows = append(table.Rows, Row{
|
|
Values: anyToStr(cols),
|
|
})
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) LoadTableMaxRows(table *Table, maxRows int) error {
|
|
sql := fmt.Sprintf("SELECT * FROM %v LIMIT ?", table.Name)
|
|
return m.loadTableRaw(table, table.Columns, sql, maxRows)
|
|
}
|
|
|
|
func (m *Manager) LoadTable(table *Table) error {
|
|
sql := fmt.Sprintf("SELECT * FROM %v", table.Name)
|
|
return m.loadTableRaw(table, table.Columns, sql)
|
|
}
|
|
|
|
// anyToStr renders scanned column values as display strings, one per
// element of a. Each element must be one of the pointer types loadTableRaw
// hands to rows.Scan:
//
//   - *string, *sql.NullString:   the text unchanged
//   - *int, *sql.NullInt64:       base-10 integer
//   - *float64, *sql.NullFloat64: fixed-point with two decimals
//   - *[]byte:                    lowercase hex of at most the first 512 bytes
//
// sql.Null* values that are not Valid render as "NULL". Any other element
// type panics.
func anyToStr(a []any) []string {
	// Only a preview of large blobs is rendered to keep rows displayable.
	const maxBlobBytes = 512

	res := make([]string, len(a))

	for i, c := range a {
		switch v := c.(type) {
		case *string:
			res[i] = *v
		case *int:
			res[i] = strconv.Itoa(*v)
		case *float64:
			res[i] = strconv.FormatFloat(*v, 'f', 2, 64)
		case *[]byte:
			buf := *v
			if len(buf) > maxBlobBytes {
				buf = buf[:maxBlobBytes]
			}
			res[i] = hex.EncodeToString(buf)
		case *sql.NullInt64:
			res[i] = "NULL"
			if v.Valid {
				// FormatInt keeps the full 64-bit value; converting to int
				// first would truncate on 32-bit platforms.
				res[i] = strconv.FormatInt(v.Int64, 10)
			}
		case *sql.NullFloat64:
			res[i] = "NULL"
			if v.Valid {
				res[i] = strconv.FormatFloat(v.Float64, 'f', 2, 64)
			}
		case *sql.NullString:
			res[i] = "NULL"
			if v.Valid {
				res[i] = v.String
			}
		default:
			panic("THIS SHOULD NEVER HAPPEN, WE GOT SERVED AN UNKNOWN TYPE")
		}
	}

	return res
}
|
|
|
|
func (m *Manager) references(cts *engine.CreateTableStatement) error {
|
|
table, ok := m.GetTable(cts.TableName)
|
|
if !ok {
|
|
return fmt.Errorf("No table with name found, name: %v", cts.TableName)
|
|
}
|
|
|
|
for i, column := range cts.Columns {
|
|
flags := extrasToFlags(column.Extra)
|
|
if !flags.Has(FOREIGN_KEY) {
|
|
continue
|
|
}
|
|
|
|
index := slices.IndexFunc(column.Extra, func(c string) bool {
|
|
return strings.HasPrefix(c, "ref")
|
|
})
|
|
refExtra := column.Extra[index]
|
|
refStr := strings.Split(refExtra, " ")[1]
|
|
|
|
s := strings.Split(refStr, ".")
|
|
tableName := s[0]
|
|
columnName := s[1]
|
|
|
|
refTable, ok := m.GetTable(tableName)
|
|
if !ok {
|
|
return fmt.Errorf("Reference table '%v' not found", tableName)
|
|
} else {
|
|
colIndex := slices.IndexFunc(refTable.Columns, func(c Column) bool {
|
|
return c.Name == columnName
|
|
})
|
|
|
|
table.Columns[i].Reference = &refTable.Columns[colIndex]
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStatement) (Table, error) {
|
|
res := Table{
|
|
Name: cts.TableName,
|
|
Columns: make([]Column, len(cts.Columns)),
|
|
Rows: make([]Row, 0),
|
|
}
|
|
|
|
for i, column := range cts.Columns {
|
|
flags := extrasToFlags(column.Extra)
|
|
var columnType ColumnType
|
|
switch column.Type {
|
|
case "REAL":
|
|
columnType = REAL
|
|
case "BLOB":
|
|
columnType = BLOB
|
|
case "TEXT":
|
|
columnType = TEXT
|
|
case "INTEGER", "NUMERIC":
|
|
columnType = INTEGER
|
|
default:
|
|
panic("This shouldnt happen")
|
|
}
|
|
|
|
res.Columns[i] = Column{
|
|
Type: columnType,
|
|
Name: column.Name,
|
|
Flags: flags,
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func extrasToFlags(extras []string) ColumnFlag {
|
|
res := NONE
|
|
for _, extra := range extras {
|
|
|
|
// This is not good
|
|
switch strings.Split(extra, " ")[0] {
|
|
case "PRIMARY_KEY":
|
|
res |= PRIMARY_KEY
|
|
case "ref":
|
|
res |= FOREIGN_KEY
|
|
case "NOT_NULL":
|
|
res |= NOT_NULL
|
|
case "AUTOINCREMENT":
|
|
res |= AUTO_INCREMENT
|
|
default:
|
|
log.Panicf("NOT IMPLEMENTED EXTRA: %v", extra)
|
|
}
|
|
}
|
|
|
|
return res
|
|
}
|