package sql

import (
	"fmt"
	"io"
	"slices"
)

// Parser turns the token stream produced by a Lexer into a
// CreateTableStatement. The most recently scanned token is kept in last so
// that helpers and error messages can refer to it.
type Parser struct {
	s    *Lexer
	last struct {
		pos Position
		tok Token
		lit string
	}
}

// NewParser returns a Parser that reads its tokens from a new Lexer over r.
func NewParser(r io.Reader) *Parser {
	return &Parser{s: NewLexer(r)}
}

// Parse reads a single CREATE TABLE statement and returns it.
//
// The statement must start with CREATE TABLE <ident> "(" and is followed by
// column definitions and/or FOREIGN KEY clauses. It returns io.EOF when the
// very first token is EOF, and a descriptive error for any unexpected token.
// A FOREIGN KEY clause naming a column that has not been declared yet is
// reported as an error.
func (p *Parser) Parse() (*CreateTableStatement, error) {
	first, ok := p.expectOne(CREATE, EOF)
	if !ok {
		return nil, p.unexpectedToken(CREATE, EOF)
	}
	if first == EOF {
		return nil, io.EOF
	}

	if !p.expectSequence(TABLE, IDENT) {
		return nil, p.unexpectedToken()
	}
	stmt := CreateTableStatement{
		TableName: p.last.lit,
		Columns:   make([]Column, 0),
	}
	if !p.expectNext(LPAREN) {
		return nil, p.unexpectedToken(LPAREN)
	}

	for {
		// Remember the token before this scan: parseColumn and references
		// consume their own closing ')' so a following ';' is only valid
		// when that previous token was ')'.
		prev := p.last.tok
		_, tok, _ := p.scan()
		switch tok {
		case IDENT:
			column, err := p.parseColumn()
			if err != nil {
				return nil, err
			}
			stmt.Columns = append(stmt.Columns, column)
		case RPAREN:
			if !p.expectNext(SEMI) {
				return nil, p.unexpectedToken(SEMI)
			}
			return &stmt, nil
		case SEMI:
			if prev != RPAREN {
				return nil, p.unexpectedToken(RPAREN)
			}
			return &stmt, nil
		case FOREIGN:
			if err := p.parseForeignKey(&stmt); err != nil {
				return nil, err
			}
		default:
			return nil, p.unexpectedToken(IDENT, RPAREN, SEMI, FOREIGN)
		}
	}
}

// parseForeignKey parses the rest of a table-level
// "FOREIGN KEY (column) REFERENCES table(column)" clause; the FOREIGN token
// has already been consumed. The resulting reference is appended to the Extra
// of the column with the given name. It returns an error if the clause is
// malformed or names a column that is not present in stmt.
func (p *Parser) parseForeignKey(stmt *CreateTableStatement) error {
	if !p.expectSequence(KEY, LPAREN, IDENT) {
		return p.unexpectedToken()
	}
	name, pos := p.last.lit, p.last.pos

	if !p.expectSequence(RPAREN, REFERENCES) {
		return p.unexpectedToken()
	}
	ref, err := p.references()
	if err != nil {
		return err
	}

	idx := slices.IndexFunc(stmt.Columns, func(c Column) bool {
		return c.Name == name
	})
	if idx < 0 {
		// IndexFunc yields -1 for an unknown column; indexing with it would panic.
		return fmt.Errorf("foreign key refers to unknown column '%v' on pos: %v", name, pos)
	}
	stmt.Columns[idx].Extra = append(stmt.Columns[idx].Extra, ref)
	return nil
}

// parseColumn parses one column definition. The column name is the literal of
// the IDENT token scanned last by the caller. The name must be followed by a
// TEXT or INTEGER type and then any number of constraints (PRIMARY KEY,
// NOT NULL, REFERENCES table(column)). The definition ends at, and consumes,
// the next ',' or ')'.
func (p *Parser) parseColumn() (Column, error) {
	column := Column{Name: p.last.lit, Extra: make([]string, 0)}

	_, tok, lit := p.scan()
	switch tok {
	case TEXT, INTEGER:
		column.Type = lit
	default:
		return Column{}, p.unexpectedToken(TEXT, INTEGER)
	}

	for {
		_, tok, lit := p.scan()
		switch tok {
		case COMMA, RPAREN:
			return column, nil
		case PRIMARY, NOT:
			// PRIMARY must be followed by KEY and NOT by NULL; the mixed
			// forms "PRIMARY NULL" and "NOT KEY" are rejected.
			want := Token(KEY)
			if tok == NOT {
				want = NULL
			}
			if !p.expectNext(want) {
				return Column{}, p.unexpectedToken(want)
			}
			column.Extra = append(column.Extra, fmt.Sprintf("%v_%v", lit, p.last.lit))
		case REFERENCES:
			ref, err := p.references()
			if err != nil {
				return Column{}, err
			}
			column.Extra = append(column.Extra, ref)
		default:
			return Column{}, p.unexpectedToken(COMMA, RPAREN, PRIMARY, NOT, REFERENCES)
		}
	}
}

// unexpectedToken builds an error describing the last scanned token. When
// expected tokens are supplied they are included in the message: a single
// token verbatim, several as a comma-separated list.
func (p *Parser) unexpectedToken(expected ...Token) error {
	switch len(expected) {
	case 0:
		return fmt.Errorf("Encountered unexpected token: %v lit: '%v' on pos: %v", p.last.tok, p.last.lit, p.last.pos)
	case 1:
		return fmt.Errorf(
			"Encountered unexpected token: %v lit: '%v' on pos: %v, expected %v",
			p.last.tok, p.last.lit, p.last.pos, expected[0],
		)
	default:
		return fmt.Errorf(
			"Encountered unexpected token: %v lit: '%v' on pos: %v, expected one of '%v'",
			p.last.tok, p.last.lit, p.last.pos, arrayToString(expected, ", "),
		)
	}
}

// references parses "table(column)" following a REFERENCES token and returns
// it formatted as "ref table.column".
func (p *Parser) references() (string, error) {
	if !p.expectNext(IDENT) {
		return "", p.unexpectedToken(IDENT)
	}
	table := p.last.lit

	if !p.expectSequence(LPAREN, IDENT) {
		return "", p.unexpectedToken()
	}
	column := p.last.lit

	if !p.expectNext(RPAREN) {
		return "", p.unexpectedToken(RPAREN)
	}
	return fmt.Sprintf("ref %v.%v", table, column), nil
}

// expectSequence scans one token per entry of token and reports whether every
// scanned token matched its entry. Scanning stops at the first mismatch, which
// is then available in p.last.
func (p *Parser) expectSequence(token ...Token) bool {
	for _, want := range token {
		if !p.expectNext(want) {
			return false
		}
	}
	return true
}

// expectNext scans the next token and reports whether it equals token.
func (p *Parser) expectNext(token Token) bool {
	_, got, _ := p.scan()
	return got == token
}

// expectOne scans the next token and returns it together with a flag telling
// whether it is one of the given tokens.
func (p *Parser) expectOne(token ...Token) (Token, bool) {
	_, got, _ := p.scan()
	return got, slices.Contains(token, got)
}

// scan fetches the next token from the lexer, records it in p.last and
// returns its position, kind and literal.
func (p *Parser) scan() (Position, Token, string) {
	pos, tok, lit := p.s.Lex()
	// fmt.Printf("Scanning next Token: %v | pos: %v | lit: %v\n", tok, pos, lit)
	p.last.pos, p.last.tok, p.last.lit = pos, tok, lit
	return pos, tok, lit
}