diff --git a/TASKS.md b/TASKS.md new file mode 100644 index 0000000..d175c9a --- /dev/null +++ b/TASKS.md @@ -0,0 +1,5 @@ +[] Make lexer detect Text by quotes are same as ident? +[] Make lexer understand numbers +[] Add boolean and NULL types +[] Handle extra fields like WHERE, ORDER etc, in parser, for Delete and select +[] Think about a better way than to return an error with Affected Rows when inserting or deleting diff --git a/cmd/sqv-tview/main.go b/cmd/sqv-tview/main.go index ec3a0bb..8a801a1 100644 --- a/cmd/sqv-tview/main.go +++ b/cmd/sqv-tview/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "log" engine "git.pablu.de/pablu/sqv-engine" @@ -24,7 +25,13 @@ func populateTable(tableView *tview.Table, table engine.Table) { tableView.ScrollToBeginning() } +var ( + dbFile = flag.String("path", "db.sqlite", "Use to set db path") +) + func main() { + flag.Parse() + app := tview.NewApplication() menuView := tview.NewList() @@ -49,7 +56,7 @@ func main() { AddItem(verticalFlex, 0, 4, true). AddItem(sqlEditor, 0, 1, false) - m, err := engine.NewManager("db.sqlite") + m, err := engine.NewManager(*dbFile) if err != nil { log.Fatalf("Ran into an error on opening Manager, err: %v\n", err) } diff --git a/manager.go b/manager.go index df9aea4..6862229 100644 --- a/manager.go +++ b/manager.go @@ -126,21 +126,60 @@ func (m *Manager) RunSql(sqlText string) (Table, error) { return Table{}, err } - selectStmt, ok := stmt.(*engine.SelectStatement) - if !ok { + switch v := stmt.(type) { + case *engine.SelectStatement: + return m.tableFromSelectStatement(sqlText, v) + + case *engine.InsertStatement: + if !slices.ContainsFunc(m.tables, func(t Table) bool { + return v.Table == t.Name + }) { + return Table{}, fmt.Errorf("Table not found") + } + + res, err := m.conn.Exec(sqlText) + if err != nil { + return Table{}, err + } + affected, err := res.RowsAffected() + if err != nil { + return Table{}, err + } + return Table{}, fmt.Errorf("Rows affected: %v", affected) + + case *engine.DeleteStatement: + if !slices.ContainsFunc(m.tables, func(t Table) bool { + return v.Table == t.Name + }) { + return Table{}, fmt.Errorf("Table not found") + } + + res, err := m.conn.Exec(sqlText) + if err != nil { + return Table{}, err + } + affected, err := res.RowsAffected() + if err != nil { + return Table{}, err + } + return Table{}, fmt.Errorf("Rows affected: %v", affected) + + default: return Table{}, fmt.Errorf("Input statement is not of correct Syntax, select statement") } +} - table, ok := m.GetTable(selectStmt.From) +func (m *Manager) tableFromSelectStatement(sqlText string, stmt *engine.SelectStatement) (Table, error) { + table, ok := m.GetTable(stmt.From) if !ok { return Table{}, fmt.Errorf("Selected Table does not exist, have you perhaps misstyped the table Name?") } fields := make([]Column, 0) - if slices.Contains(selectStmt.Fields, "*") { + if slices.Contains(stmt.Fields, "*") { fields = table.Columns } else { - for _, columnName := range selectStmt.Fields { + for _, columnName := range stmt.Fields { index := slices.IndexFunc(table.Columns, func(c Column) bool { if c.Name == columnName { return true @@ -153,7 +192,7 @@ func (m *Manager) RunSql(sqlText string) (Table, error) { } table.Columns = fields - err = m.loadTableRaw(&table, fields, sqlText) + err := m.loadTableRaw(&table, fields, sqlText) if err != nil { return Table{}, err } diff --git a/sql/ast.go b/sql/ast.go index 7504a6d..d19da23 100644 --- a/sql/ast.go +++ b/sql/ast.go @@ -33,5 +33,17 @@ type SelectStatement struct { Fields []string } +type InsertStatement struct { + Table string + Values map[string]any +} + +type DeleteStatement struct { + Table string + Extra []string +} + func (_ *CreateTableStatement) isEnumValue() {} func (_ *SelectStatement) isEnumValue() {} +func (_ *InsertStatement) isEnumValue() {} +func (_ *DeleteStatement) isEnumValue() {} diff --git a/sql/lexer.go b/sql/lexer.go index 7c51200..6636ac0 100644 --- a/sql/lexer.go +++ b/sql/lexer.go @@ -29,6 +29,11 @@ const ( CREATE TABLE + INSERT + INTO + VALUES + RETURNING + SELECT FROM WHERE @@ -37,6 +42,8 @@ const ( ORDER TOP + DELETE + PRIMARY FOREIGN REFERENCES @@ -80,6 +87,11 @@ var keywords map[string]Token = map[string]Token{ "AUTOINCREMENT": AUTOINCREMENT, "CONSTRAINT": CONSTRAINT, "NUMERIC": NUMERIC, + "INSERT": INSERT, + "INTO": INTO, + "VALUES": VALUES, + "RETURNING": RETURNING, + "DELETE": DELETE, } type Position struct { diff --git a/sql/parseCreateStatement.go b/sql/parseCreateStatement.go new file mode 100644 index 0000000..79101e8 --- /dev/null +++ b/sql/parseCreateStatement.go @@ -0,0 +1,219 @@ +package sql + +import ( + "fmt" + "slices" +) + +func (p *Parser) parseCreateTable() (*CreateTableStatement, error) { + if !p.expectNext(TABLE) { + return nil, p.unexpectedToken() + } + + tok, ok := p.expectOne(QUOTE, SINGLE_QUOTE, BACKQUOTE, IDENT, IF) + if !ok { + return nil, p.unexpectedToken(IDENT, IF) + } + + switch tok { + case IF: + if !p.expectSequence(NOT, EXISTS) { + return nil, p.unexpectedToken() + } + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + fallthrough + case QUOTE, SINGLE_QUOTE, BACKQUOTE: + if !p.expectNext(IDENT) { + return nil, p.unexpectedToken() + } + } + _, _, lit := p.rescan() + + stmt := CreateTableStatement{ + TableName: lit, + Columns: make([]Column, 0), + } + + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + if !p.expectNext(LPAREN) { + return nil, p.unexpectedToken(LPAREN) + } + + for { + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + _, tok, _ := p.scan() + + switch tok { + case RPAREN: + if !p.expectNext(SEMI) { + return nil, p.unexpectedToken(SEMI) + } + return &stmt, nil + + case IDENT: + column, err := p.parseColumn() + if err != nil { + return nil, err + } + stmt.Columns = append(stmt.Columns, column) + + // TODO: HANDLE AND SAVE CONSTRAINTS + case CONSTRAINT: + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + if !p.expectNext(IDENT) { + return nil, p.unexpectedToken() + } + // _, _, constraintName := p.rescan() + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + + case FOREIGN: + if !p.expectSequence(KEY, LPAREN) { + return nil, p.unexpectedToken() + } + + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + + if !p.expectNext(IDENT) { + return nil, p.unexpectedToken() + } + _, _, columnName := p.rescan() + + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + + if !p.expectSequence(RPAREN, REFERENCES) { + return nil, p.unexpectedToken() + } + + ref, err := p.references() + if err != nil { + return nil, err + } + + column := slices.IndexFunc(stmt.Columns, func(c Column) bool { + return c.Name == columnName + }) + + stmt.Columns[column].Extra = append(stmt.Columns[column].Extra, ref) + + case PRIMARY: + if !p.expectSequence(KEY, LPAREN) { + return nil, p.unexpectedToken() + } + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + + if !p.expectNext(IDENT) { + return nil, p.unexpectedToken() + } + + primaryKeyNames := make([]string, 0) + _, _, columnName := p.rescan() + primaryKeyNames = append(primaryKeyNames, columnName) + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + + for { + tok, ok := p.expectOne(RPAREN, COMMA) + if !ok { + return nil, p.unexpectedToken() + } + if tok == RPAREN { + break + } + + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + if !p.expectNext(IDENT) { + return nil, p.unexpectedToken() + } + + _, _, columnName := p.rescan() + + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + primaryKeyNames = append(primaryKeyNames, columnName) + } + + for _, pkName := range primaryKeyNames { + column := slices.IndexFunc(stmt.Columns, func(c Column) bool { + return c.Name == pkName + }) + stmt.Columns[column].Extra = append(stmt.Columns[column].Extra, "PRIMARY_KEY") + } + + case COMMA: + continue + + default: + return nil, p.unexpectedToken(IDENT, RPAREN, FOREIGN, COMMA) + } + } +} + +func (p *Parser) parseColumn() (Column, error) { + _, _, lit := p.rescan() + column := Column{Name: lit, Extra: make([]string, 0)} + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + + if _, ok := p.expectOne(TEXT, INTEGER, REAL, BLOB, NUMERIC); !ok { + return Column{}, p.unexpectedToken(TEXT, INTEGER, REAL, BLOB, NUMERIC) + } + _, _, column.Type = p.rescan() + + for { + _, tok, lit := p.scan() + switch tok { + case COMMA: + return column, nil + case RPAREN: + p.unscan() + return column, nil + + case PRIMARY: + fallthrough + case NOT: + if _, ok := p.expectOne(NULL, KEY); !ok { + return Column{}, p.unexpectedToken(NULL, KEY) + } + _, _, rlit := p.rescan() + column.Extra = append(column.Extra, fmt.Sprintf("%v_%v", lit, rlit)) + + case REFERENCES: + ref, err := p.references() + if err != nil { + return Column{}, err + } + column.Extra = append(column.Extra, ref) + + case AUTOINCREMENT: + column.Extra = append(column.Extra, "AUTOINCREMENT") + + default: + return Column{}, p.unexpectedToken(COMMA, RPAREN, PRIMARY, NOT, REFERENCES) + } + } +} + +func (p *Parser) references() (string, error) { + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + if !p.expectNext(IDENT) { + return "", p.unexpectedToken(IDENT) + } + _, _, referenceTableName := p.rescan() + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + + if !p.expectNext(LPAREN) { + return "", p.unexpectedToken() + } + + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + + if !p.expectNext(IDENT) { + return "", p.unexpectedToken() + } + _, _, referenceColumnName := p.rescan() + + p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) + + if !p.expectNext(RPAREN) { + return "", p.unexpectedToken(RPAREN) + } + + return fmt.Sprintf("ref %v.%v", referenceTableName, referenceColumnName), nil +} diff --git a/sql/parseDeleteStatement.go b/sql/parseDeleteStatement.go new file mode 100644 index 0000000..cd576c1 --- /dev/null +++ b/sql/parseDeleteStatement.go @@ -0,0 +1,15 @@ +package sql + +func (p *Parser) parseDelete() (*DeleteStatement, error) { + if !p.expectSequence(FROM, IDENT) { + return nil, p.unexpectedToken(INTO) + } + + res := DeleteStatement{} + + _, _, res.Table = p.rescan() + + p.consumeUntilOne(50, EOF, SEMI) + + return &res, nil +} diff --git a/sql/parseInsertStatement.go b/sql/parseInsertStatement.go new file mode 100644 index 0000000..276525d --- /dev/null +++ b/sql/parseInsertStatement.go @@ -0,0 +1,69 @@ +package sql + +import "fmt" + +func (p *Parser) parseInsert() (*InsertStatement, error) { + if !p.expectSequence(INTO, IDENT) { + return nil, p.unexpectedToken(INTO) + } + + res := InsertStatement{} + + _, _, res.Table = p.rescan() + + if !p.expectNext(LPAREN) { + return nil, p.unexpectedToken(LPAREN) + } + + fieldNames := make([]string, 0) + + for loop := true; loop; { + _, tok, val := p.scan() + switch tok { + case IDENT: + fieldNames = append(fieldNames, val) + case RPAREN: + loop = false + case COMMA: + continue + default: + return nil, p.unexpectedToken(IDENT, RPAREN, COMMA) + } + } + + if !p.expectSequence(VALUES, LPAREN) { + return nil, p.unexpectedToken() + } + + values := make([]any, 0, len(fieldNames)) + for loop := true; loop; { + _, tok, val := p.scan() + switch tok { + case IDENT: + // TODO, convert to actual datatype? + values = append(values, val) + case COMMA, QUOTE, SINGLE_QUOTE, BACKQUOTE: + continue + case RPAREN: + loop = false + default: + return nil, p.unexpectedToken(IDENT, RPAREN, COMMA) + } + } + + if len(values) != len(fieldNames) { + return nil, fmt.Errorf("Expected same amount of Values as Fields, but got %v fields, and %v values", fieldNames, values) + } + + // Handle things like RETURNING *, also handle multiple Values + if !p.consumeUntilOne(50, SEMI, EOF) { + return nil, fmt.Errorf("Expected semicolon but never found after 50 tries") + } + + res.Values = make(map[string]any) + for i, name := range fieldNames { + res.Values[name] = values[i] + } + + return &res, nil +} diff --git a/sql/parseSelectStatement.go b/sql/parseSelectStatement.go new file mode 100644 index 0000000..adc7cbf --- /dev/null +++ b/sql/parseSelectStatement.go @@ -0,0 +1,47 @@ +package sql + +import "fmt" + +func (p *Parser) parseSelect() (*SelectStatement, error) { + tok, ok := p.expectOne(ASTERIKS, IDENT) + if !ok { + return nil, p.unexpectedToken(ASTERIKS, IDENT) + } + + fields := make([]string, 1) + fields[0] = "*" + if tok == IDENT { + _, _, n := p.rescan() + fields[0] = n + for { + tok, ok := p.expectOne(COMMA, FROM) + if !ok { + return nil, p.unexpectedToken(COMMA, FROM) + } + if tok == FROM { + p.unscan() + break + } + + if !p.expectNext(IDENT) { + return nil, p.unexpectedToken(IDENT) + } + _, _, n := p.rescan() + fields = append(fields, n) + } + } + + if !p.expectSequence(FROM, IDENT) { + return nil, p.unexpectedToken() + } + + _, _, tableName := p.rescan() + if !p.consumeUntilOne(50, SEMI, EOF) { + return nil, fmt.Errorf("Expected semicolon but never found after 50 tries") + } + + return &SelectStatement{ + From: tableName, + Fields: fields, + }, nil +} diff --git a/sql/parser.go b/sql/parser.go index dc7aa79..8d4c5a4 100644 --- a/sql/parser.go +++ b/sql/parser.go @@ -23,9 +23,9 @@ func NewParser(r io.Reader) *Parser { } func (p *Parser) Parse() (Statement, error) { - tok, ok := p.expectOne(CREATE, EOF, SELECT) + tok, ok := p.expectOne(CREATE, EOF, SELECT, INSERT, DELETE) if !ok { - return nil, p.unexpectedToken(CREATE, EOF) + return nil, p.unexpectedToken(CREATE, EOF, SELECT, INSERT, DELETE) } else if tok == EOF { return nil, io.EOF } @@ -37,240 +37,15 @@ func (p *Parser) Parse() (Statement, error) { return p.parseCreateTable() case SELECT: return p.parseSelect() + case INSERT: + return p.parseInsert() + case DELETE: + return p.parseDelete() default: panic("SHOULD NEVER BE REACHED") } } -func (p *Parser) parseSelect() (*SelectStatement, error) { - tok, ok := p.expectOne(ASTERIKS, IDENT) - if !ok { - return nil, p.unexpectedToken(ASTERIKS, IDENT) - } - - fields := make([]string, 1) - fields[0] = "*" - if tok == IDENT { - _, _, n := p.rescan() - fields[0] = n - for { - tok, ok := p.expectOne(COMMA, FROM) - if !ok { - return nil, p.unexpectedToken(COMMA, FROM) - } - if tok == FROM { - p.unscan() - break - } - - if !p.expectNext(IDENT) { - return nil, p.unexpectedToken(IDENT) - } - _, _, n := p.rescan() - fields = append(fields, n) - } - } - - if !p.expectSequence(FROM, IDENT) { - return nil, p.unexpectedToken() - } - - _, _, tableName := p.rescan() - if !p.consumeUntilOne(50, SEMI, EOF) { - return nil, fmt.Errorf("Expected semicolon but never found after 50 tries") - } - - return &SelectStatement{ - From: tableName, - Fields: fields, - }, nil -} - -func (p *Parser) parseCreateTable() (*CreateTableStatement, error) { - if !p.expectNext(TABLE) { - return nil, p.unexpectedToken() - } - - tok, ok := p.expectOne(QUOTE, SINGLE_QUOTE, BACKQUOTE, IDENT, IF) - if !ok { - return nil, p.unexpectedToken(IDENT, IF) - } - - switch tok { - case IF: - if !p.expectSequence(NOT, EXISTS) { - return nil, p.unexpectedToken() - } - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - fallthrough - case QUOTE, SINGLE_QUOTE, BACKQUOTE: - if !p.expectNext(IDENT) { - return nil, p.unexpectedToken() - } - } - _, _, lit := p.rescan() - - stmt := CreateTableStatement{ - TableName: lit, - Columns: make([]Column, 0), - } - - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - if !p.expectNext(LPAREN) { - return nil, p.unexpectedToken(LPAREN) - } - - for { - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - _, tok, _ := p.scan() - - switch tok { - case RPAREN: - if !p.expectNext(SEMI) { - return nil, p.unexpectedToken(SEMI) - } - return &stmt, nil - - case IDENT: - column, err := p.parseColumn() - if err != nil { - return nil, err - } - stmt.Columns = append(stmt.Columns, column) - - // TODO: HANDLE AND SAVE CONSTRAINTS - case CONSTRAINT: - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - if !p.expectNext(IDENT) { - return nil, p.unexpectedToken() - } - // _, _, constraintName := p.rescan() - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - - case FOREIGN: - if !p.expectSequence(KEY, LPAREN) { - return nil, p.unexpectedToken() - } - - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - - if !p.expectNext(IDENT) { - return nil, p.unexpectedToken() - } - _, _, columnName := p.rescan() - - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - - if !p.expectSequence(RPAREN, REFERENCES) { - return nil, p.unexpectedToken() - } - - ref, err := p.references() - if err != nil { - return nil, err - } - - column := slices.IndexFunc(stmt.Columns, func(c Column) bool { - return c.Name == columnName - }) - - stmt.Columns[column].Extra = append(stmt.Columns[column].Extra, ref) - - case PRIMARY: - if !p.expectSequence(KEY, LPAREN) { - return nil, p.unexpectedToken() - } - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - - if !p.expectNext(IDENT) { - return nil, p.unexpectedToken() - } - - primaryKeyNames := make([]string, 0) - _, _, columnName := p.rescan() - primaryKeyNames = append(primaryKeyNames, columnName) - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - - for { - tok, ok := p.expectOne(RPAREN, COMMA) - if !ok { - return nil, p.unexpectedToken() - } - if tok == RPAREN { - break - } - - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - if !p.expectNext(IDENT) { - return nil, p.unexpectedToken() - } - - _, _, columnName := p.rescan() - - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - primaryKeyNames = append(primaryKeyNames, columnName) - } - - for _, pkName := range primaryKeyNames { - column := slices.IndexFunc(stmt.Columns, func(c Column) bool { - return c.Name == pkName - }) - stmt.Columns[column].Extra = append(stmt.Columns[column].Extra, "PRIMARY_KEY") - } - - case COMMA: - continue - - default: - return nil, p.unexpectedToken(IDENT, RPAREN, FOREIGN, COMMA) - } - } -} - -func (p *Parser) parseColumn() (Column, error) { - _, _, lit := p.rescan() - column := Column{Name: lit, Extra: make([]string, 0)} - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - - if _, ok := p.expectOne(TEXT, INTEGER, REAL, BLOB, NUMERIC); !ok { - return Column{}, p.unexpectedToken(TEXT, INTEGER, REAL, BLOB, NUMERIC) - } - _, _, column.Type = p.rescan() - - for { - _, tok, lit := p.scan() - switch tok { - case COMMA: - return column, nil - case RPAREN: - p.unscan() - return column, nil - - case PRIMARY: - fallthrough - case NOT: - if _, ok := p.expectOne(NULL, KEY); !ok { - return Column{}, p.unexpectedToken(NULL, KEY) - } - _, _, rlit := p.rescan() - column.Extra = append(column.Extra, fmt.Sprintf("%v_%v", lit, rlit)) - - case REFERENCES: - ref, err := p.references() - if err != nil { - return Column{}, err - } - column.Extra = append(column.Extra, ref) - - case AUTOINCREMENT: - column.Extra = append(column.Extra, "AUTOINCREMENT") - - default: - return Column{}, p.unexpectedToken(COMMA, RPAREN, PRIMARY, NOT, REFERENCES) - } - } -} - func (p *Parser) unexpectedToken(expected ...Token) error { l := len(expected) pos, tok, lit := p.rescan() @@ -295,34 +70,6 @@ func (p *Parser) unexpectedToken(expected ...Token) error { } } -func (p *Parser) references() (string, error) { - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - if !p.expectNext(IDENT) { - return "", p.unexpectedToken(IDENT) - } - _, _, referenceTableName := p.rescan() - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - - if !p.expectNext(LPAREN) { - return "", p.unexpectedToken() - } - - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - - if !p.expectNext(IDENT) { - return "", p.unexpectedToken() - } - _, _, referenceColumnName := p.rescan() - - p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE) - - if !p.expectNext(RPAREN) { - return "", p.unexpectedToken(RPAREN) - } - - return fmt.Sprintf("ref %v.%v", referenceTableName, referenceColumnName), nil -} - func (p *Parser) expectSequence(token ...Token) bool { for _, tok := range token { if !p.expectNext(tok) {