package sql

import (
	"errors"
	"fmt"
	"io"
	"slices"
)

// Parser turns the token stream produced by a Lexer into CREATE TABLE
// statements. It remembers the most recently scanned token so that error
// messages can report what was found and where.
type Parser struct {
	s *Lexer
	// last is the token most recently returned by scan; helpers such as
	// expectNext and unexpectedToken read it after scanning.
	last struct {
		pos Position
		tok Token
		lit string
	}
}

// NewParser returns a Parser that reads SQL text from r.
func NewParser(r io.Reader) *Parser {
	return &Parser{s: NewLexer(r)}
}

// Parse reads the next CREATE TABLE statement from the input. The accepted
// shape is
//
//	CREATE TABLE name ( element [, element ...] ) ;
//
// where each element is either a column definition (see parseColumn) or
//
//	FOREIGN KEY ( column ) REFERENCES table ( column )
//
// A FOREIGN KEY clause attaches its reference to a column that must already
// have been declared earlier in the list; otherwise an error is returned.
// An empty list and a trailing comma before the closing paren are tolerated.
//
// Parse returns io.EOF when the input ends before a statement begins, so
// callers can loop until io.EOF to read several statements.
func (p *Parser) Parse() (*CreateTableStatement, error) {
	tok, ok := p.expectOne(CREATE, EOF)
	if !ok {
		return nil, p.unexpectedToken()
	}
	if tok == EOF {
		return nil, io.EOF
	}
	if !p.expectNext(TABLE) {
		return nil, errors.New("Expect TABLE token")
	}
	if !p.expectNext(IDENT) {
		return nil, errors.New("Expect IDENT token")
	}
	stmt := CreateTableStatement{
		TableName: p.last.lit,
		Columns:   make([]Column, 0),
	}
	if !p.expectNext(LPAREN) {
		return nil, errors.New("Expect LPAREN token")
	}

	// Consume table elements until the list's closing RPAREN. That RPAREN can
	// be consumed in three places: by parseColumn (which stops on COMMA or
	// RPAREN), by the separator check after a FOREIGN KEY clause, or directly
	// here. Tracking it explicitly keeps the RPAREN of a nested
	// "REFERENCES t(c)" from being mistaken for the end of the list.
	for closed := false; !closed; {
		_, next, _ := p.scan()
		switch next {
		case IDENT:
			column, err := p.parseColumn()
			if err != nil {
				return nil, err
			}
			stmt.Columns = append(stmt.Columns, column)
			closed = p.last.tok == RPAREN
		case RPAREN:
			closed = true
		case FOREIGN:
			if err := p.parseForeignKey(&stmt); err != nil {
				return nil, err
			}
			sep, valid := p.expectOne(COMMA, RPAREN)
			if !valid {
				return nil, p.unexpectedToken()
			}
			closed = sep == RPAREN
		default:
			return nil, p.unexpectedToken()
		}
	}

	if !p.expectNext(SEMI) {
		return nil, p.unexpectedToken()
	}
	return &stmt, nil
}

// parseForeignKey parses "KEY ( column ) REFERENCES table ( column )" after
// the FOREIGN token has been consumed. It appends the reference string to the
// Extra list of the named column, which must already be present in stmt.
func (p *Parser) parseForeignKey(stmt *CreateTableStatement) error {
	// Short-circuiting stops scanning at the first mismatch, so p.last holds
	// the offending token for the error message.
	if !p.expectNext(KEY) || !p.expectNext(LPAREN) || !p.expectNext(IDENT) {
		return p.unexpectedToken()
	}
	columnName := p.last.lit
	if !p.expectNext(RPAREN) || !p.expectNext(REFERENCES) {
		return p.unexpectedToken()
	}
	ref, err := p.references()
	if err != nil {
		return err
	}
	idx := slices.IndexFunc(stmt.Columns, func(c Column) bool { return c.Name == columnName })
	if idx < 0 {
		// Previously this indexed Columns[-1] and panicked.
		return fmt.Errorf("foreign key references undeclared column %q", columnName)
	}
	stmt.Columns[idx].Extra = append(stmt.Columns[idx].Extra, ref)
	return nil
}

// parseColumn parses a column definition whose name (the IDENT token) has
// just been scanned. Accepted shape:
//
//	name (TEXT | INTEGER) { PRIMARY KEY | NOT NULL | REFERENCES table ( column ) }
//
// Constraints are recorded in Extra as "<lit>_<lit>" (for PRIMARY KEY / NOT
// NULL) or as the string built by references. Parsing stops after consuming
// the COMMA or RPAREN that ends the definition; callers can inspect p.last.tok
// to tell which one it was.
func (p *Parser) parseColumn() (Column, error) {
	column := Column{Name: p.last.lit, Extra: make([]string, 0)}

	_, tok, lit := p.scan()
	switch tok {
	case TEXT, INTEGER:
		column.Type = lit
	default:
		return Column{}, p.unexpectedToken()
	}

	for {
		_, tok, lit := p.scan()
		switch tok {
		case COMMA, RPAREN:
			return column, nil
		case PRIMARY, NOT:
			// PRIMARY must be followed by KEY and NOT by NULL; the previous
			// code accepted either keyword after either prefix.
			want := Token(KEY)
			if tok == NOT {
				want = NULL
			}
			if !p.expectNext(want) {
				return Column{}, p.unexpectedToken()
			}
			column.Extra = append(column.Extra, fmt.Sprintf("%v_%v", lit, p.last.lit))
		case REFERENCES:
			ref, err := p.references()
			if err != nil {
				return Column{}, err
			}
			column.Extra = append(column.Extra, ref)
		default:
			return Column{}, p.unexpectedToken()
		}
	}
}

// unexpectedToken builds an error describing the most recently scanned token
// and its position.
func (p *Parser) unexpectedToken() error {
	return fmt.Errorf("Encountered unexpected token: %v lit: '%v' on pos: %v", p.last.tok, p.last.lit, p.last.pos)
}

// references parses "table ( column )" after the REFERENCES token has been
// consumed and returns it formatted as "ref table.column".
func (p *Parser) references() (string, error) {
	if !p.expectNext(IDENT) {
		return "", p.unexpectedToken()
	}
	referenceTableName := p.last.lit
	if !p.expectNext(LPAREN) {
		return "", p.unexpectedToken()
	}
	if !p.expectNext(IDENT) {
		return "", p.unexpectedToken()
	}
	referenceColumnName := p.last.lit
	if !p.expectNext(RPAREN) {
		return "", p.unexpectedToken()
	}
	return fmt.Sprintf("ref %v.%v", referenceTableName, referenceColumnName), nil
}

// expectNext scans one token and reports whether it equals token.
func (p *Parser) expectNext(token Token) bool {
	_, tok, _ := p.scan()
	return tok == token
}

// expectOne scans one token and reports whether it is any of the given
// tokens, returning the scanned token either way.
func (p *Parser) expectOne(token ...Token) (Token, bool) {
	_, tok, _ := p.scan()
	return tok, slices.Contains(token, tok)
}

// scan reads the next token from the lexer and records it in p.last.
func (p *Parser) scan() (Position, Token, string) {
	pos, tok, lit := p.s.Lex()
	p.last.pos, p.last.tok, p.last.lit = pos, tok, lit
	return pos, tok, lit
}