209 lines
3.8 KiB
Go
209 lines
3.8 KiB
Go
package sql
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"slices"
|
|
)
|
|
|
|
// Parser turns the token stream produced by a Lexer into CREATE TABLE
// statements. It remembers the most recently scanned token so that
// error messages can report its position and literal.
type Parser struct {
	s *Lexer // underlying lexer supplying (position, token, literal) triples

	// last caches the result of the most recent scan(); written by
	// scan() and read by unexpectedToken() and the parse methods.
	last struct {
		pos Position
		tok Token
		lit string
	}
}
|
|
|
|
func NewParser(r io.Reader) *Parser {
|
|
return &Parser{s: NewLexer(r)}
|
|
}
|
|
|
|
func (p *Parser) Parse() (*CreateTableStatement, error) {
|
|
tok, ok := p.expectOne(CREATE, EOF)
|
|
if !ok {
|
|
return nil, p.unexpectedToken()
|
|
} else if tok == EOF {
|
|
return nil, io.EOF
|
|
}
|
|
|
|
if !p.expectNext(TABLE) {
|
|
return nil, errors.New("Expect TABLE token")
|
|
}
|
|
|
|
if !p.expectNext(IDENT) {
|
|
return nil, errors.New("Expect IDENT token")
|
|
}
|
|
|
|
stmt := CreateTableStatement{
|
|
TableName: p.last.lit,
|
|
Columns: make([]Column, 0),
|
|
}
|
|
|
|
_, tok, _ = p.scan()
|
|
if tok != LPAREN {
|
|
return nil, errors.New("Expect LPAREN token")
|
|
}
|
|
|
|
for {
|
|
lastTok := p.last.tok
|
|
_, tok, _ := p.scan()
|
|
|
|
switch tok {
|
|
case IDENT:
|
|
column, err := p.parseColumn()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
stmt.Columns = append(stmt.Columns, column)
|
|
|
|
case RPAREN:
|
|
if !p.expectNext(SEMI) {
|
|
return nil, p.unexpectedToken()
|
|
}
|
|
return &stmt, nil
|
|
|
|
case SEMI:
|
|
if lastTok != RPAREN {
|
|
return nil, p.unexpectedToken()
|
|
}
|
|
|
|
return &stmt, nil
|
|
|
|
case FOREIGN:
|
|
if !p.expectNext(KEY) {
|
|
return nil, p.unexpectedToken()
|
|
}
|
|
if !p.expectNext(LPAREN) {
|
|
return nil, p.unexpectedToken()
|
|
}
|
|
|
|
if !p.expectNext(IDENT) {
|
|
return nil, p.unexpectedToken()
|
|
}
|
|
columnName := p.last.lit
|
|
|
|
if !p.expectNext(RPAREN) {
|
|
return nil, p.unexpectedToken()
|
|
}
|
|
|
|
if !p.expectNext(REFERENCES) {
|
|
return nil, p.unexpectedToken()
|
|
}
|
|
|
|
ref, err := p.references()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
column := slices.IndexFunc(stmt.Columns, func(c Column) bool {
|
|
return c.Name == columnName
|
|
})
|
|
|
|
stmt.Columns[column].Extra = append(stmt.Columns[column].Extra, ref)
|
|
|
|
default:
|
|
return nil, p.unexpectedToken()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Parser) parseColumn() (Column, error) {
|
|
column := Column{Name: p.last.lit, Extra: make([]string, 0)}
|
|
|
|
_, tok, lit := p.scan()
|
|
switch tok {
|
|
case TEXT:
|
|
fallthrough
|
|
case INTEGER:
|
|
column.Type = lit
|
|
default:
|
|
return Column{}, p.unexpectedToken()
|
|
}
|
|
|
|
for {
|
|
_, tok, lit := p.scan()
|
|
switch tok {
|
|
case COMMA:
|
|
fallthrough
|
|
case RPAREN:
|
|
return column, nil
|
|
|
|
case PRIMARY:
|
|
fallthrough
|
|
case NOT:
|
|
if _, ok := p.expectOne(NULL, KEY); !ok {
|
|
return Column{}, p.unexpectedToken()
|
|
}
|
|
column.Extra = append(column.Extra, fmt.Sprintf("%v_%v", lit, p.last.lit))
|
|
|
|
case REFERENCES:
|
|
ref, err := p.references()
|
|
if err != nil {
|
|
return Column{}, err
|
|
}
|
|
column.Extra = append(column.Extra, ref)
|
|
|
|
default:
|
|
return Column{}, p.unexpectedToken()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Parser) unexpectedToken() error {
|
|
return fmt.Errorf("Encountered unexpected token: %v lit: '%v' on pos: %v", p.last.tok, p.last.lit, p.last.pos)
|
|
}
|
|
|
|
func (p *Parser) references() (string, error) {
|
|
if !p.expectNext(IDENT) {
|
|
return "", p.unexpectedToken()
|
|
}
|
|
referenceTableName := p.last.lit
|
|
|
|
if !p.expectNext(LPAREN) {
|
|
return "", p.unexpectedToken()
|
|
}
|
|
|
|
if !p.expectNext(IDENT) {
|
|
return "", p.unexpectedToken()
|
|
}
|
|
|
|
referenceColumnName := p.last.lit
|
|
|
|
if !p.expectNext(RPAREN) {
|
|
return "", p.unexpectedToken()
|
|
}
|
|
|
|
return fmt.Sprintf("ref %v.%v", referenceTableName, referenceColumnName), nil
|
|
}
|
|
|
|
func (p *Parser) expectNext(token Token) bool {
|
|
_, tok, _ := p.scan()
|
|
return tok == token
|
|
}
|
|
|
|
func (p *Parser) expectOne(token ...Token) (Token, bool) {
|
|
_, tok, _ := p.scan()
|
|
ok := slices.ContainsFunc(token, func(t Token) bool {
|
|
return tok == t
|
|
})
|
|
return tok, ok
|
|
}
|
|
|
|
func (p *Parser) scan() (Position, Token, string) {
|
|
pos, tok, lit := p.s.Lex()
|
|
// fmt.Printf("Scanning next Token: %v | pos: %v | lit: %v\n", tok, pos, lit)
|
|
|
|
p.last = struct {
|
|
pos Position
|
|
tok Token
|
|
lit string
|
|
}{
|
|
pos,
|
|
tok,
|
|
lit,
|
|
}
|
|
return pos, tok, lit
|
|
}
|