diff --git a/manager.go b/manager.go index 9de6486..e1a3d61 100644 --- a/manager.go +++ b/manager.go @@ -62,115 +62,8 @@ func NewManager(path string) (*Manager, error) { }, nil } -func (m *Manager) RunSql(sqlText string) (Table, error) { - p := engine.NewParser(strings.NewReader(sqlText)) - stmt, err := p.Parse() - if err != nil && !errors.Is(err, io.EOF) { - return Table{}, err - } - - selectStmt, ok := stmt.(*engine.SelectStatement) - if !ok { - panic("HELP ITS NOT A SELECT STATMET") - } - - table, ok := m.GetTable(selectStmt.From) - if !ok { - panic("HELP TABLE NOT FOUND") - } - - rows, err := m.conn.Query(sqlText) - if err != nil { - return Table{}, err - } - table.Rows = make([]Row, 0) - - if slices.Contains(selectStmt.Fields, "*") { - for rows.Next() { - cols := make([]any, len(table.Columns)) - - for i, column := range table.Columns { - switch column.Type { - case BLOB: - cols[i] = new([]byte) - case TEXT: - cols[i] = new(string) - case INTEGER: - cols[i] = new(int) - case REAL: - cols[i] = new(float64) - default: - panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE") - } - } - - err = rows.Scan(cols...) - if err != nil { - return Table{}, err - } - - table.Rows = append(table.Rows, Row{ - Values: anyToStr(cols), - }) - } - - return table, nil - } else { - columns := make([]Column, len(selectStmt.Fields)) - firstTime := true - - for rows.Next() { - - cols := make([]any, len(selectStmt.Fields)) - - for i, s := range selectStmt.Fields { - for _, column := range table.Columns { - if column.Name != s { - continue - } - - if firstTime { - columns[i] = column - } - - switch column.Type { - case BLOB: - cols[i] = new([]byte) - case TEXT: - cols[i] = new(string) - case INTEGER: - cols[i] = new(int) - case REAL: - cols[i] = new(float64) - default: - panic("THIS SHOULD NEVER HAPPEN, WE HIT AN UNKNOWN COLUMN.TYPE") - } - } - } - - firstTime = false - - err = rows.Scan(cols...) - if err != nil { - return Table{}, err - } - - table.Rows = append(table.Rows, Row{ - Values: anyToStr(cols), - }) - } - - nTable := Table{ - Name: selectStmt.From, - Columns: columns, - Rows: table.Rows, - } - - return nTable, nil - } -} - func (m *Manager) Start() error { + createTableStatements := make([]*engine.CreateTableStatement, 0) for { stmt, err := m.parser.Parse() if err != nil && errors.Is(err, io.EOF) { @@ -187,12 +80,21 @@ func (m *Manager) Start() error { return err } m.tables = append(m.tables, t) + createTableStatements = append(createTableStatements, v) default: panic("NOT IMPLEMENTED") } } + // Rethink how to do this cleanly + for _, cts := range createTableStatements { + err := m.references(cts) + if err != nil { + return err + } + } + return nil } @@ -220,7 +122,47 @@ func (m *Manager) GetTable(name string) (Table, bool) { return table, true } -func (m *Manager) loadTableRaw(table *Table, s string, args ...any) error { +func (m *Manager) RunSql(sqlText string) (Table, error) { + p := engine.NewParser(strings.NewReader(sqlText)) + stmt, err := p.Parse() + if err != nil && !errors.Is(err, io.EOF) { + return Table{}, err + } + + selectStmt, ok := stmt.(*engine.SelectStatement) + if !ok { + panic("HELP ITS NOT A SELECT STATMET") + } + + table, ok := m.GetTable(selectStmt.From) + if !ok { + panic("HELP TABLE NOT FOUND") + } + + table.Rows = make([]Row, 0) + + fields := make([]Column, 0) + if slices.Contains(selectStmt.Fields, "*") { + fields = table.Columns + } else { + for _, columnName := range selectStmt.Fields { + index := slices.IndexFunc(table.Columns, func(c Column) bool { + if c.Name == columnName { + return true + } + return false + }) + + fields = append(fields, table.Columns[index]) + } + } + table.Columns = fields + + err = m.loadTableRaw(&table, fields, sqlText) + return table, err +} + +func (m *Manager) loadTableRaw(table *Table, fields []Column, s string, args ...any) error { rows, err := m.conn.Query(s, args...) if err != nil { return err @@ -228,8 +170,8 @@ func (m *Manager) loadTableRaw(table *Table, s string, args ...any) error { table.Rows = make([]Row, 0) for rows.Next() { - cols := make([]any, len(table.Columns)) - for i, column := range table.Columns { + cols := make([]any, len(fields)) + for i, column := range fields { if column.Flags.Has(NOT_NULL) { switch column.Type { case BLOB: @@ -274,12 +216,12 @@ func (m *Manager) loadTableRaw(table *Table, s string, args ...any) error { func (m *Manager) LoadTableMaxRows(table *Table, maxRows int) error { sql := fmt.Sprintf("SELECT * FROM %v LIMIT ?", table.Name) - return m.loadTableRaw(table, sql, maxRows) + return m.loadTableRaw(table, table.Columns, sql, maxRows) } func (m *Manager) LoadTable(table *Table) error { sql := fmt.Sprintf("SELECT * FROM %v", table.Name) - return m.loadTableRaw(table, sql) + return m.loadTableRaw(table, table.Columns, sql) } func anyToStr(a []any) []string { @@ -326,6 +268,44 @@ func anyToStr(a []any) []string { return res } +func (m *Manager) references(cts *engine.CreateTableStatement) error { + table, ok := m.GetTable(cts.TableName) + if !ok { + return fmt.Errorf("No table with name found, name: %v", cts.TableName) + } + + for i, column := range cts.Columns { + flags := extrasToFlags(column.Extra) + if !flags.Has(FOREIGN_KEY) { + continue + } + + index := slices.IndexFunc(column.Extra, func(c string) bool { + return strings.HasPrefix(c, "ref") + }) + refExtra := column.Extra[index] + refStr := strings.Split(refExtra, " ")[1] + + s := strings.Split(refStr, ".") + tableName := s[0] + columnName := s[1] + + refTable, ok := m.GetTable(tableName) + if !ok { + fmt.Println(m.tables) + return fmt.Errorf("Reference table '%v' not found", tableName) + } else { + colIndex := slices.IndexFunc(refTable.Columns, func(c Column) bool { + return c.Name == columnName + }) + + table.Columns[i].Reference = &refTable.Columns[colIndex] + } + } + + return nil +} + func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStatement) (Table, error) { res := Table{ Name: cts.TableName, @@ -335,33 +315,6 @@ func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStat for i, column := range cts.Columns { flags := extrasToFlags(column.Extra) - var ref *Column = nil - if flags.Has(FOREIGN_KEY) { - index := slices.IndexFunc(column.Extra, func(c string) bool { - return strings.HasPrefix(c, "ref") - }) - refExtra := column.Extra[index] - refStr := strings.Split(refExtra, " ")[1] - - s := strings.Split(refStr, ".") - tableName := s[0] - columnName := s[1] - - refTable, ok := m.GetTable(tableName) - if !ok { - fmt.Println(m.tables) - fmt.Println(res) - fmt.Println("Reference was skipped") - // return Table{}, fmt.Errorf("Reference table '%v' not found", tableName) - } else { - colIndex := slices.IndexFunc(refTable.Columns, func(c Column) bool { - return c.Name == columnName - }) - - ref = &refTable.Columns[colIndex] - } - } - var columnType ColumnType switch column.Type { case "REAL": @@ -377,10 +330,9 @@ func (m *Manager) convertCreateTableStatementToTable(cts *engine.CreateTableStat } res.Columns[i] = Column{ - Type: columnType, - Name: column.Name, - Reference: ref, - Flags: flags, + Type: columnType, + Name: column.Name, + Flags: flags, } }