Deduplicate code and use loadTableRaw for RunSql

This commit is contained in:
Pablu
2025-12-02 10:10:57 +01:00
parent 9c51424bf8
commit b7147d03c2

View File

@@ -62,115 +62,8 @@ func NewManager(path string) (*Manager, error) {
}, nil }, nil
} }
func (m *Manager) RunSql(sqlText string) (Table, error) {
p := engine.NewParser(strings.NewReader(sqlText))
stmt, err := p.Parse()
if err != nil && !errors.Is(err, io.EOF) {
return Table{}, err
}
selectStmt, ok := stmt.(*engine.SelectStatement)
if !ok {
panic("HELP ITS NOT A SELECT STATMET")
}
table, ok := m.GetTable(selectStmt.From)
if !ok {
panic("HELP TABLE NOT FOUND")
}
rows, err := m.conn.Query(sqlText)
if err != nil {
return Table{}, err
}
table.Rows = make([]Row, 0)
if slices.Contains(selectStmt.Fields, "*") {
for rows.Next() {
cols := make([]any, len(table.Columns))
for i, column := range table.Columns {
switch column.Type {
case BLOB:
cols[i] = new([]byte)
case TEXT:
cols[i] = new(string)
case INTEGER:
cols[i] = new(int)
case REAL:
cols[i] = new(float64)
default:
panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE")
}
}
err = rows.Scan(cols...)
if err != nil {
return Table{}, err
}
table.Rows = append(table.Rows, Row{
Values: anyToStr(cols),
})
}
return table, nil
} else {
columns := make([]Column, len(selectStmt.Fields))
firstTime := true
for rows.Next() {
cols := make([]any, len(selectStmt.Fields))
for i, s := range selectStmt.Fields {
for _, column := range table.Columns {
if column.Name != s {
continue
}
if firstTime {
columns[i] = column
}
switch column.Type {
case BLOB:
cols[i] = new([]byte)
case TEXT:
cols[i] = new(string)
case INTEGER:
cols[i] = new(int)
case REAL:
cols[i] = new(float64)
default:
panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE")
}
}
}
firstTime = false
err = rows.Scan(cols...)
if err != nil {
return Table{}, err
}
table.Rows = append(table.Rows, Row{
Values: anyToStr(cols),
})
}
nTable := Table{
Name: selectStmt.From,
Columns: columns,
Rows: table.Rows,
}
return nTable, nil
}
}
func (m *Manager) Start() error { func (m *Manager) Start() error {
createTableStatements := make([]*engine.CreateTableStatement, 0)
for { for {
stmt, err := m.parser.Parse() stmt, err := m.parser.Parse()
if err != nil && errors.Is(err, io.EOF) { if err != nil && errors.Is(err, io.EOF) {
@@ -187,12 +80,21 @@ func (m *Manager) Start() error {
return err return err
} }
m.tables = append(m.tables, t) m.tables = append(m.tables, t)
createTableStatements = append(createTableStatements, v)
default: default:
panic("NOT IMPLEMENTED") panic("NOT IMPLEMENTED")
} }
} }
// Rethink how to do this cleanly
for _, cts := range createTableStatements {
err := m.references(cts)
if err != nil {
return err
}
}
return nil return nil
} }
@@ -220,7 +122,47 @@ func (m *Manager) GetTable(name string) (Table, bool) {
return table, true return table, true
} }
func (m *Manager) loadTableRaw(table *Table, s string, args ...any) error { func (m *Manager) RunSql(sqlText string) (Table, error) {
p := engine.NewParser(strings.NewReader(sqlText))
stmt, err := p.Parse()
if err != nil && !errors.Is(err, io.EOF) {
return Table{}, err
}
selectStmt, ok := stmt.(*engine.SelectStatement)
if !ok {
panic("HELP ITS NOT A SELECT STATMET")
}
table, ok := m.GetTable(selectStmt.From)
if !ok {
panic("HELP TABLE NOT FOUND")
}
table.Rows = make([]Row, 0)
fields := make([]Column, 0)
if slices.Contains(selectStmt.Fields, "*") {
fields = table.Columns
} else {
for _, columnName := range selectStmt.Fields {
index := slices.IndexFunc(table.Columns, func(c Column) bool {
if c.Name == columnName {
return true
}
return false
})
fields = append(fields, table.Columns[index])
}
}
table.Columns = fields
err = m.loadTableRaw(&table, fields, sqlText)
return table, err
}
func (m *Manager) loadTableRaw(table *Table, fields []Column, s string, args ...any) error {
rows, err := m.conn.Query(s, args...) rows, err := m.conn.Query(s, args...)
if err != nil { if err != nil {
return err return err
@@ -228,8 +170,8 @@ func (m *Manager) loadTableRaw(table *Table, s string, args ...any) error {
table.Rows = make([]Row, 0) table.Rows = make([]Row, 0)
for rows.Next() { for rows.Next() {
cols := make([]any, len(table.Columns)) cols := make([]any, len(fields))
for i, column := range table.Columns { for i, column := range fields {
if column.Flags.Has(NOT_NULL) { if column.Flags.Has(NOT_NULL) {
switch column.Type { switch column.Type {
case BLOB: case BLOB:
@@ -274,12 +216,12 @@ func (m *Manager) loadTableRaw(table *Table, s string, args ...any) error {
func (m *Manager) LoadTableMaxRows(table *Table, maxRows int) error { func (m *Manager) LoadTableMaxRows(table *Table, maxRows int) error {
sql := fmt.Sprintf("SELECT * FROM %v LIMIT ?", table.Name) sql := fmt.Sprintf("SELECT * FROM %v LIMIT ?", table.Name)
return m.loadTableRaw(table, sql, maxRows) return m.loadTableRaw(table, table.Columns, sql, maxRows)
} }
func (m *Manager) LoadTable(table *Table) error { func (m *Manager) LoadTable(table *Table) error {
sql := fmt.Sprintf("SELECT * FROM %v", table.Name) sql := fmt.Sprintf("SELECT * FROM %v", table.Name)
return m.loadTableRaw(table, sql) return m.loadTableRaw(table, table.Columns, sql)
} }
func anyToStr(a []any) []string { func anyToStr(a []any) []string {
@@ -326,17 +268,18 @@ func anyToStr(a []any) []string {
return res return res
} }
func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStatement) (Table, error) { func (m *Manager) references(cts *engine.CreateTableStatement) error {
res := Table{ table, ok := m.GetTable(cts.TableName)
Name: cts.TableName, if !ok {
Columns: make([]Column, len(cts.Columns)), return fmt.Errorf("No table with name found, name: %v", cts.TableName)
Rows: make([]Row, 0),
} }
for i, column := range cts.Columns { for i, column := range cts.Columns {
flags := extrasToFlags(column.Extra) flags := extrasToFlags(column.Extra)
var ref *Column = nil if !flags.Has(FOREIGN_KEY) {
if flags.Has(FOREIGN_KEY) { continue
}
index := slices.IndexFunc(column.Extra, func(c string) bool { index := slices.IndexFunc(column.Extra, func(c string) bool {
return strings.HasPrefix(c, "ref") return strings.HasPrefix(c, "ref")
}) })
@@ -350,18 +293,28 @@ func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStat
refTable, ok := m.GetTable(tableName) refTable, ok := m.GetTable(tableName)
if !ok { if !ok {
fmt.Println(m.tables) fmt.Println(m.tables)
fmt.Println(res) return fmt.Errorf("Reference table '%v' not found", tableName)
fmt.Println("Reference was skipped")
// return Table{}, fmt.Errorf("Reference table '%v' not found", tableName)
} else { } else {
colIndex := slices.IndexFunc(refTable.Columns, func(c Column) bool { colIndex := slices.IndexFunc(refTable.Columns, func(c Column) bool {
return c.Name == columnName return c.Name == columnName
}) })
ref = &refTable.Columns[colIndex] table.Columns[i].Reference = &refTable.Columns[colIndex]
} }
} }
return nil
}
func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStatement) (Table, error) {
res := Table{
Name: cts.TableName,
Columns: make([]Column, len(cts.Columns)),
Rows: make([]Row, 0),
}
for i, column := range cts.Columns {
flags := extrasToFlags(column.Extra)
var columnType ColumnType var columnType ColumnType
switch column.Type { switch column.Type {
case "REAL": case "REAL":
@@ -379,7 +332,6 @@ func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStat
res.Columns[i] = Column{ res.Columns[i] = Column{
Type: columnType, Type: columnType,
Name: column.Name, Name: column.Name,
Reference: ref,
Flags: flags, Flags: flags,
} }
} }