Deduplicate code and use loadTableRaw for RunSql
This commit is contained in:
202
manager.go
202
manager.go
@@ -62,115 +62,8 @@ func NewManager(path string) (*Manager, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) RunSql(sqlText string) (Table, error) {
|
|
||||||
p := engine.NewParser(strings.NewReader(sqlText))
|
|
||||||
stmt, err := p.Parse()
|
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
return Table{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
selectStmt, ok := stmt.(*engine.SelectStatement)
|
|
||||||
if !ok {
|
|
||||||
panic("HELP ITS NOT A SELECT STATMET")
|
|
||||||
}
|
|
||||||
|
|
||||||
table, ok := m.GetTable(selectStmt.From)
|
|
||||||
if !ok {
|
|
||||||
panic("HELP TABLE NOT FOUND")
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := m.conn.Query(sqlText)
|
|
||||||
if err != nil {
|
|
||||||
return Table{}, err
|
|
||||||
}
|
|
||||||
table.Rows = make([]Row, 0)
|
|
||||||
|
|
||||||
if slices.Contains(selectStmt.Fields, "*") {
|
|
||||||
for rows.Next() {
|
|
||||||
cols := make([]any, len(table.Columns))
|
|
||||||
|
|
||||||
for i, column := range table.Columns {
|
|
||||||
switch column.Type {
|
|
||||||
case BLOB:
|
|
||||||
cols[i] = new([]byte)
|
|
||||||
case TEXT:
|
|
||||||
cols[i] = new(string)
|
|
||||||
case INTEGER:
|
|
||||||
cols[i] = new(int)
|
|
||||||
case REAL:
|
|
||||||
cols[i] = new(float64)
|
|
||||||
default:
|
|
||||||
panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rows.Scan(cols...)
|
|
||||||
if err != nil {
|
|
||||||
return Table{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
table.Rows = append(table.Rows, Row{
|
|
||||||
Values: anyToStr(cols),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return table, nil
|
|
||||||
} else {
|
|
||||||
columns := make([]Column, len(selectStmt.Fields))
|
|
||||||
firstTime := true
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
|
|
||||||
cols := make([]any, len(selectStmt.Fields))
|
|
||||||
|
|
||||||
for i, s := range selectStmt.Fields {
|
|
||||||
for _, column := range table.Columns {
|
|
||||||
if column.Name != s {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if firstTime {
|
|
||||||
columns[i] = column
|
|
||||||
}
|
|
||||||
|
|
||||||
switch column.Type {
|
|
||||||
case BLOB:
|
|
||||||
cols[i] = new([]byte)
|
|
||||||
case TEXT:
|
|
||||||
cols[i] = new(string)
|
|
||||||
case INTEGER:
|
|
||||||
cols[i] = new(int)
|
|
||||||
case REAL:
|
|
||||||
cols[i] = new(float64)
|
|
||||||
default:
|
|
||||||
panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
firstTime = false
|
|
||||||
|
|
||||||
err = rows.Scan(cols...)
|
|
||||||
if err != nil {
|
|
||||||
return Table{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
table.Rows = append(table.Rows, Row{
|
|
||||||
Values: anyToStr(cols),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
nTable := Table{
|
|
||||||
Name: selectStmt.From,
|
|
||||||
Columns: columns,
|
|
||||||
Rows: table.Rows,
|
|
||||||
}
|
|
||||||
|
|
||||||
return nTable, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) Start() error {
|
func (m *Manager) Start() error {
|
||||||
|
createTableStatements := make([]*engine.CreateTableStatement, 0)
|
||||||
for {
|
for {
|
||||||
stmt, err := m.parser.Parse()
|
stmt, err := m.parser.Parse()
|
||||||
if err != nil && errors.Is(err, io.EOF) {
|
if err != nil && errors.Is(err, io.EOF) {
|
||||||
@@ -187,12 +80,21 @@ func (m *Manager) Start() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m.tables = append(m.tables, t)
|
m.tables = append(m.tables, t)
|
||||||
|
createTableStatements = append(createTableStatements, v)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
panic("NOT IMPLEMENTED")
|
panic("NOT IMPLEMENTED")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rethink how to do this cleanly
|
||||||
|
for _, cts := range createTableStatements {
|
||||||
|
err := m.references(cts)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,7 +122,47 @@ func (m *Manager) GetTable(name string) (Table, bool) {
|
|||||||
return table, true
|
return table, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) loadTableRaw(table *Table, s string, args ...any) error {
|
func (m *Manager) RunSql(sqlText string) (Table, error) {
|
||||||
|
p := engine.NewParser(strings.NewReader(sqlText))
|
||||||
|
stmt, err := p.Parse()
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
return Table{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
selectStmt, ok := stmt.(*engine.SelectStatement)
|
||||||
|
if !ok {
|
||||||
|
panic("HELP ITS NOT A SELECT STATMET")
|
||||||
|
}
|
||||||
|
|
||||||
|
table, ok := m.GetTable(selectStmt.From)
|
||||||
|
if !ok {
|
||||||
|
panic("HELP TABLE NOT FOUND")
|
||||||
|
}
|
||||||
|
|
||||||
|
table.Rows = make([]Row, 0)
|
||||||
|
|
||||||
|
fields := make([]Column, 0)
|
||||||
|
if slices.Contains(selectStmt.Fields, "*") {
|
||||||
|
fields = table.Columns
|
||||||
|
} else {
|
||||||
|
for _, columnName := range selectStmt.Fields {
|
||||||
|
index := slices.IndexFunc(table.Columns, func(c Column) bool {
|
||||||
|
if c.Name == columnName {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
|
fields = append(fields, table.Columns[index])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
table.Columns = fields
|
||||||
|
|
||||||
|
err = m.loadTableRaw(&table, fields, sqlText)
|
||||||
|
return table, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) loadTableRaw(table *Table, fields []Column, s string, args ...any) error {
|
||||||
rows, err := m.conn.Query(s, args...)
|
rows, err := m.conn.Query(s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -228,8 +170,8 @@ func (m *Manager) loadTableRaw(table *Table, s string, args ...any) error {
|
|||||||
table.Rows = make([]Row, 0)
|
table.Rows = make([]Row, 0)
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
cols := make([]any, len(table.Columns))
|
cols := make([]any, len(fields))
|
||||||
for i, column := range table.Columns {
|
for i, column := range fields {
|
||||||
if column.Flags.Has(NOT_NULL) {
|
if column.Flags.Has(NOT_NULL) {
|
||||||
switch column.Type {
|
switch column.Type {
|
||||||
case BLOB:
|
case BLOB:
|
||||||
@@ -274,12 +216,12 @@ func (m *Manager) loadTableRaw(table *Table, s string, args ...any) error {
|
|||||||
|
|
||||||
func (m *Manager) LoadTableMaxRows(table *Table, maxRows int) error {
|
func (m *Manager) LoadTableMaxRows(table *Table, maxRows int) error {
|
||||||
sql := fmt.Sprintf("SELECT * FROM %v LIMIT ?", table.Name)
|
sql := fmt.Sprintf("SELECT * FROM %v LIMIT ?", table.Name)
|
||||||
return m.loadTableRaw(table, sql, maxRows)
|
return m.loadTableRaw(table, table.Columns, sql, maxRows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) LoadTable(table *Table) error {
|
func (m *Manager) LoadTable(table *Table) error {
|
||||||
sql := fmt.Sprintf("SELECT * FROM %v", table.Name)
|
sql := fmt.Sprintf("SELECT * FROM %v", table.Name)
|
||||||
return m.loadTableRaw(table, sql)
|
return m.loadTableRaw(table, table.Columns, sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
func anyToStr(a []any) []string {
|
func anyToStr(a []any) []string {
|
||||||
@@ -326,17 +268,18 @@ func anyToStr(a []any) []string {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStatement) (Table, error) {
|
func (m *Manager) references(cts *engine.CreateTableStatement) error {
|
||||||
res := Table{
|
table, ok := m.GetTable(cts.TableName)
|
||||||
Name: cts.TableName,
|
if !ok {
|
||||||
Columns: make([]Column, len(cts.Columns)),
|
return fmt.Errorf("No table with name found, name: %v", cts.TableName)
|
||||||
Rows: make([]Row, 0),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, column := range cts.Columns {
|
for i, column := range cts.Columns {
|
||||||
flags := extrasToFlags(column.Extra)
|
flags := extrasToFlags(column.Extra)
|
||||||
var ref *Column = nil
|
if !flags.Has(FOREIGN_KEY) {
|
||||||
if flags.Has(FOREIGN_KEY) {
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
index := slices.IndexFunc(column.Extra, func(c string) bool {
|
index := slices.IndexFunc(column.Extra, func(c string) bool {
|
||||||
return strings.HasPrefix(c, "ref")
|
return strings.HasPrefix(c, "ref")
|
||||||
})
|
})
|
||||||
@@ -350,18 +293,28 @@ func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStat
|
|||||||
refTable, ok := m.GetTable(tableName)
|
refTable, ok := m.GetTable(tableName)
|
||||||
if !ok {
|
if !ok {
|
||||||
fmt.Println(m.tables)
|
fmt.Println(m.tables)
|
||||||
fmt.Println(res)
|
return fmt.Errorf("Reference table '%v' not found", tableName)
|
||||||
fmt.Println("Reference was skipped")
|
|
||||||
// return Table{}, fmt.Errorf("Reference table '%v' not found", tableName)
|
|
||||||
} else {
|
} else {
|
||||||
colIndex := slices.IndexFunc(refTable.Columns, func(c Column) bool {
|
colIndex := slices.IndexFunc(refTable.Columns, func(c Column) bool {
|
||||||
return c.Name == columnName
|
return c.Name == columnName
|
||||||
})
|
})
|
||||||
|
|
||||||
ref = &refTable.Columns[colIndex]
|
table.Columns[i].Reference = &refTable.Columns[colIndex]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStatement) (Table, error) {
|
||||||
|
res := Table{
|
||||||
|
Name: cts.TableName,
|
||||||
|
Columns: make([]Column, len(cts.Columns)),
|
||||||
|
Rows: make([]Row, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, column := range cts.Columns {
|
||||||
|
flags := extrasToFlags(column.Extra)
|
||||||
var columnType ColumnType
|
var columnType ColumnType
|
||||||
switch column.Type {
|
switch column.Type {
|
||||||
case "REAL":
|
case "REAL":
|
||||||
@@ -379,7 +332,6 @@ func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStat
|
|||||||
res.Columns[i] = Column{
|
res.Columns[i] = Column{
|
||||||
Type: columnType,
|
Type: columnType,
|
||||||
Name: column.Name,
|
Name: column.Name,
|
||||||
Reference: ref,
|
|
||||||
Flags: flags,
|
Flags: flags,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user