package sql

import (
	"fmt"
	"io"
	"slices"
)

// Parser turns the token stream produced by a Lexer into Statements.
//
// It keeps a single token of lookahead (buf) so that the most recently
// scanned token can be pushed back with unscan and read again.
type Parser struct {
	s   *Lexer
	buf LexRes
}

// LexRes caches the most recent result of Lexer.Lex for the Parser.
type LexRes struct {
	pos   Position
	tok   Token
	lit   string
	avail bool // when true, the next scan returns this cached token instead of lexing
}

// NewParser returns a Parser that reads SQL text from r.
func NewParser(r io.Reader) *Parser {
	return &Parser{s: NewLexer(r)}
}

// Parse reads the next statement from the input. CREATE TABLE and SELECT
// statements are supported; io.EOF is returned once the input is exhausted.
//
// On failure the returned Statement is always a nil interface, never an
// interface wrapping a nil *CreateTableStatement or *SelectStatement, so
// callers may safely compare it against nil.
func (p *Parser) Parse() (Statement, error) {
	tok, ok := p.expectOne(CREATE, SELECT, EOF)
	if !ok {
		return nil, p.unexpectedToken(CREATE, SELECT, EOF)
	}
	switch tok {
	case CREATE:
		stmt, err := p.parseCreateTable()
		if err != nil {
			return nil, err
		}
		return stmt, nil
	case SELECT:
		stmt, err := p.parseSelect()
		if err != nil {
			return nil, err
		}
		return stmt, nil
	default: // EOF: the only other token accepted by expectOne above
		return nil, io.EOF
	}
}

// parseSelect parses the remainder of a SELECT statement (the SELECT keyword
// has already been consumed):
//
//	( * | ident [, ident ...] ) FROM ident ... ( ; | EOF )
//
// A "*" projection is stored as the single field "*".
func (p *Parser) parseSelect() (*SelectStatement, error) {
	tok, ok := p.expectOne(ASTERIKS, IDENT)
	if !ok {
		return nil, p.unexpectedToken(ASTERIKS, IDENT)
	}
	fields := []string{"*"}
	if tok == IDENT {
		_, _, name := p.rescan()
		fields[0] = name
		for {
			sep, ok := p.expectOne(COMMA, FROM)
			if !ok {
				return nil, p.unexpectedToken(COMMA, FROM)
			}
			if sep == FROM {
				// Leave FROM for the shared check below.
				p.unscan()
				break
			}
			if !p.expectNext(IDENT) {
				return nil, p.unexpectedToken(IDENT)
			}
			_, _, name := p.rescan()
			fields = append(fields, name)
		}
	}
	if !p.expectNext(FROM) {
		return nil, p.unexpectedToken(FROM)
	}
	if !p.expectNext(IDENT) {
		return nil, p.unexpectedToken(IDENT)
	}
	_, _, tableName := p.rescan()

	// Any tokens between the table name and the terminator are discarded;
	// give up after a bounded number of them.
	const maxTrailingTokens = 50
	if !p.consumeUntilOne(maxTrailingTokens, SEMI, EOF) {
		return nil, fmt.Errorf("Expected semicolon but never found after %d tries", maxTrailingTokens)
	}
	return &SelectStatement{
		From:   tableName,
		Fields: fields,
	}, nil
}

// parseCreateTable parses the remainder of a CREATE TABLE statement (the
// CREATE keyword has already been consumed):
//
//	TABLE [IF NOT EXISTS] name ( definition [, definition ...] ) ( ; | EOF )
//
// where each definition is a column definition (see parseColumn) or one of
// the table constraints "CONSTRAINT name", "FOREIGN KEY (col) REFERENCES
// tbl (col)" and "PRIMARY KEY (col [, col ...])". Table-level FOREIGN KEY
// and PRIMARY KEY constraints are recorded in the Extra list of the columns
// they name; naming a column that has not been declared yet is an error.
func (p *Parser) parseCreateTable() (*CreateTableStatement, error) {
	if !p.expectNext(TABLE) {
		return nil, p.unexpectedToken(TABLE)
	}
	if _, tok, _ := p.scan(); tok == IF {
		// TODO: IF NOT EXISTS is accepted but not recorded on the statement.
		if !p.expectNext(NOT) {
			return nil, p.unexpectedToken(NOT)
		}
		if !p.expectNext(EXISTS) {
			return nil, p.unexpectedToken(EXISTS)
		}
	} else {
		p.unscan()
	}

	tableName, err := p.parseIdentifier()
	if err != nil {
		return nil, err
	}
	stmt := CreateTableStatement{
		TableName: tableName,
		Columns:   make([]Column, 0),
	}

	if !p.expectNext(LPAREN) {
		return nil, p.unexpectedToken(LPAREN)
	}
	for {
		// A quoted column name starts with a quote token; skip it so the
		// IDENT case below sees the name itself.
		p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE)
		_, tok, _ := p.scan()
		switch tok {
		case RPAREN:
			end, ok := p.expectOne(SEMI, EOF)
			if !ok {
				return nil, p.unexpectedToken(SEMI, EOF)
			}
			if end == EOF {
				// Leave EOF buffered so the next Parse call reports io.EOF.
				p.unscan()
			}
			return &stmt, nil
		case IDENT:
			column, err := p.parseColumn()
			if err != nil {
				return nil, err
			}
			stmt.Columns = append(stmt.Columns, column)
		case CONSTRAINT:
			// TODO: HANDLE AND SAVE CONSTRAINTS (the name is parsed but discarded).
			if _, err := p.parseIdentifier(); err != nil {
				return nil, err
			}
		case FOREIGN:
			if err := p.parseForeignKey(stmt.Columns); err != nil {
				return nil, err
			}
		case PRIMARY:
			if err := p.parsePrimaryKey(stmt.Columns); err != nil {
				return nil, err
			}
		case COMMA:
			// Separator between definitions; nothing to record.
		default:
			return nil, p.unexpectedToken(IDENT, RPAREN, CONSTRAINT, FOREIGN, PRIMARY, COMMA)
		}
	}
}

// parseForeignKey parses "KEY ( col ) REFERENCES tbl ( col )" after a
// table-level FOREIGN keyword and appends the reference (as produced by
// references) to the Extra list of the named column in columns.
func (p *Parser) parseForeignKey(columns []Column) error {
	if !p.expectNext(KEY) {
		return p.unexpectedToken(KEY)
	}
	if !p.expectNext(LPAREN) {
		return p.unexpectedToken(LPAREN)
	}
	columnName, err := p.parseIdentifier()
	if err != nil {
		return err
	}
	if !p.expectNext(RPAREN) {
		return p.unexpectedToken(RPAREN)
	}
	if !p.expectNext(REFERENCES) {
		return p.unexpectedToken(REFERENCES)
	}
	ref, err := p.references()
	if err != nil {
		return err
	}
	i, err := indexOfColumn(columns, columnName)
	if err != nil {
		return err
	}
	columns[i].Extra = append(columns[i].Extra, ref)
	return nil
}

// parsePrimaryKey parses "KEY ( col [, col ...] )" after a table-level
// PRIMARY keyword and appends "PRIMARY_KEY" to the Extra list of every
// named column in columns.
func (p *Parser) parsePrimaryKey(columns []Column) error {
	if !p.expectNext(KEY) {
		return p.unexpectedToken(KEY)
	}
	if !p.expectNext(LPAREN) {
		return p.unexpectedToken(LPAREN)
	}
	var names []string
	for {
		name, err := p.parseIdentifier()
		if err != nil {
			return err
		}
		names = append(names, name)
		sep, ok := p.expectOne(RPAREN, COMMA)
		if !ok {
			return p.unexpectedToken(RPAREN, COMMA)
		}
		if sep == RPAREN {
			break
		}
	}
	for _, name := range names {
		i, err := indexOfColumn(columns, name)
		if err != nil {
			return err
		}
		columns[i].Extra = append(columns[i].Extra, "PRIMARY_KEY")
	}
	return nil
}

// indexOfColumn returns the index of the column called name in columns, or
// an error when no such column has been declared. It replaces a bare
// slices.IndexFunc whose -1 result would otherwise be used as an index.
func indexOfColumn(columns []Column, name string) (int, error) {
	i := slices.IndexFunc(columns, func(c Column) bool { return c.Name == name })
	if i < 0 {
		return -1, fmt.Errorf("constraint references unknown column %q", name)
	}
	return i, nil
}

// parseIdentifier reads an IDENT token and returns its literal. A quote,
// single-quote or backquote token directly before and/or after the
// identifier is skipped if present; the two sides are not required to match.
func (p *Parser) parseIdentifier() (string, error) {
	p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE)
	if !p.expectNext(IDENT) {
		return "", p.unexpectedToken(IDENT)
	}
	_, _, lit := p.rescan()
	p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE)
	return lit, nil
}

// parseColumn parses a column definition. It must be called right after the
// column's IDENT token has been scanned (the name is recovered via rescan).
//
// After the name it expects a type (TEXT, INTEGER, REAL, BLOB or NUMERIC)
// followed by any number of modifiers:
//   - NOT NULL / PRIMARY KEY: recorded as the two literals joined by "_"
//   - REFERENCES tbl (col):   recorded as returned by references
//   - AUTOINCREMENT:          recorded as "AUTOINCREMENT"
//
// The definition ends at a COMMA (consumed) or an RPAREN (pushed back for
// the caller).
func (p *Parser) parseColumn() (Column, error) {
	_, _, name := p.rescan()
	column := Column{Name: name, Extra: make([]string, 0)}
	p.consumeIfOne(QUOTE, SINGLE_QUOTE, BACKQUOTE)
	if _, ok := p.expectOne(TEXT, INTEGER, REAL, BLOB, NUMERIC); !ok {
		return Column{}, p.unexpectedToken(TEXT, INTEGER, REAL, BLOB, NUMERIC)
	}
	_, _, column.Type = p.rescan()
	for {
		_, tok, lit := p.scan()
		switch tok {
		case COMMA:
			return column, nil
		case RPAREN:
			p.unscan()
			return column, nil
		case NOT, PRIMARY:
			// NOT must be followed by NULL and PRIMARY by KEY.
			var want Token = NULL
			if tok == PRIMARY {
				want = KEY
			}
			if !p.expectNext(want) {
				return Column{}, p.unexpectedToken(want)
			}
			_, _, rlit := p.rescan()
			column.Extra = append(column.Extra, fmt.Sprintf("%v_%v", lit, rlit))
		case REFERENCES:
			ref, err := p.references()
			if err != nil {
				return Column{}, err
			}
			column.Extra = append(column.Extra, ref)
		case AUTOINCREMENT:
			column.Extra = append(column.Extra, "AUTOINCREMENT")
		default:
			return Column{}, p.unexpectedToken(COMMA, RPAREN, PRIMARY, NOT, REFERENCES, AUTOINCREMENT)
		}
	}
}

// unexpectedToken builds an error describing the most recently scanned
// token (re-read via rescan) and, if given, the token(s) that were expected
// instead.
func (p *Parser) unexpectedToken(expected ...Token) error {
	pos, tok, lit := p.rescan()
	switch len(expected) {
	case 0:
		return fmt.Errorf("Encountered unexpected token: %v lit: '%v' on pos: %v", tok, lit, pos)
	case 1:
		return fmt.Errorf(
			"Encountered unexpected token: %v lit: '%v' on pos: %v, expected %v",
			tok, lit, pos, expected[0],
		)
	default:
		return fmt.Errorf(
			"Encountered unexpected token: %v lit: '%v' on pos: %v, expected one of '%v'",
			tok, lit, pos, arrayToString(expected, ", "),
		)
	}
}

// references parses "tbl ( col )" following a REFERENCES keyword (both
// identifiers optionally quoted) and returns it formatted as "ref tbl.col".
func (p *Parser) references() (string, error) {
	referenceTableName, err := p.parseIdentifier()
	if err != nil {
		return "", err
	}
	if !p.expectNext(LPAREN) {
		return "", p.unexpectedToken(LPAREN)
	}
	referenceColumnName, err := p.parseIdentifier()
	if err != nil {
		return "", err
	}
	if !p.expectNext(RPAREN) {
		return "", p.unexpectedToken(RPAREN)
	}
	return fmt.Sprintf("ref %v.%v", referenceTableName, referenceColumnName), nil
}

// expectSequence scans one token per entry of token and reports whether
// each matched in order. It stops at the first mismatch, leaving the
// mismatching token as the most recently scanned one.
func (p *Parser) expectSequence(token ...Token) bool {
	for _, tok := range token {
		if !p.expectNext(tok) {
			return false
		}
	}
	return true
}

// expectNext scans the next token and reports whether it equals token.
// The token is consumed either way.
func (p *Parser) expectNext(token Token) bool {
	_, tok, _ := p.scan()
	return tok == token
}

// expectOne scans the next token and returns it together with whether it is
// one of token. The token is consumed either way.
func (p *Parser) expectOne(token ...Token) (Token, bool) {
	_, tok, _ := p.scan()
	return tok, slices.Contains(token, tok)
}

// consumeUntilOne scans at most max tokens, stopping after the first one
// that is in token. It reports whether such a token was found.
func (p *Parser) consumeUntilOne(max int, token ...Token) bool {
	for range max {
		_, tok, _ := p.scan()
		if slices.Contains(token, tok) {
			return true
		}
	}
	return false
}

// consumeUntil is consumeUntilOne for a single token.
func (p *Parser) consumeUntil(token Token, max int) bool {
	return p.consumeUntilOne(max, token)
}

// consumeIfOne consumes the next token if it is one of token; otherwise the
// token is pushed back so the next scan returns it again.
func (p *Parser) consumeIfOne(token ...Token) {
	_, tok, _ := p.scan()
	if !slices.Contains(token, tok) {
		p.unscan()
	}
}

// consumeIf is consumeIfOne for a single token.
func (p *Parser) consumeIf(token Token) {
	p.consumeIfOne(token)
}

// scan returns the pushed-back token if there is one; otherwise it reads the
// next token from the lexer and remembers it so it can be pushed back.
func (p *Parser) scan() (Position, Token, string) {
	if p.buf.avail {
		p.buf.avail = false
		return p.buf.pos, p.buf.tok, p.buf.lit
	}
	pos, tok, lit := p.s.Lex()
	p.buf.pos, p.buf.tok, p.buf.lit = pos, tok, lit
	return pos, tok, lit
}

// unscan pushes back the most recently scanned token. Only one token of
// lookahead is kept, so repeated calls without an intervening scan have no
// additional effect.
func (p *Parser) unscan() {
	p.buf.avail = true
}

// rescan returns the most recently scanned token again and leaves it
// consumed.
func (p *Parser) rescan() (Position, Token, string) {
	p.unscan()
	return p.scan()
}