cleanup and error printing

This commit is contained in:
Pablu
2025-10-28 12:02:25 +01:00
parent 15f0d190dc
commit cd5f72bfbe
4 changed files with 68 additions and 45 deletions

View File

@@ -11,12 +11,11 @@ import (
) )
func main() { func main() {
s := ` s := `CREATE TABLE TEST(
CREATE TABLE TEST( ID text PRIMARY KEY
ID text PRIMARY KEY );
);
CREATE TABLE sessions ( CREATE TABLE sessions (
session_id text PRIMARY KEY, session_id text PRIMARY KEY,
access_token text NOT NULL, access_token text NOT NULL,
user_email text NOT NULL, user_email text NOT NULL,

View File

@@ -28,6 +28,8 @@ const (
REFERENCES REFERENCES
KEY KEY
NOT NOT
IF
EXISTS
TEXT TEXT
INTEGER INTEGER

View File

@@ -1,7 +1,6 @@
package sql package sql
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"slices" "slices"
@@ -23,17 +22,13 @@ func NewParser(r io.Reader) *Parser {
func (p *Parser) Parse() (*CreateTableStatement, error) { func (p *Parser) Parse() (*CreateTableStatement, error) {
tok, ok := p.expectOne(CREATE, EOF) tok, ok := p.expectOne(CREATE, EOF)
if !ok { if !ok {
return nil, p.unexpectedToken() return nil, p.unexpectedToken(CREATE, EOF)
} else if tok == EOF { } else if tok == EOF {
return nil, io.EOF return nil, io.EOF
} }
if !p.expectNext(TABLE) { if !p.expectSequence(TABLE, IDENT) {
return nil, errors.New("Expect TABLE token") return nil, p.unexpectedToken()
}
if !p.expectNext(IDENT) {
return nil, errors.New("Expect IDENT token")
} }
stmt := CreateTableStatement{ stmt := CreateTableStatement{
@@ -41,9 +36,9 @@ func (p *Parser) Parse() (*CreateTableStatement, error) {
Columns: make([]Column, 0), Columns: make([]Column, 0),
} }
_, tok, _ = p.scan()
if tok != LPAREN { if !p.expectNext(LPAREN) {
return nil, errors.New("Expect LPAREN token") return nil, p.unexpectedToken(LPAREN)
} }
for { for {
@@ -60,35 +55,24 @@ func (p *Parser) Parse() (*CreateTableStatement, error) {
case RPAREN: case RPAREN:
if !p.expectNext(SEMI) { if !p.expectNext(SEMI) {
return nil, p.unexpectedToken() return nil, p.unexpectedToken(SEMI)
} }
return &stmt, nil return &stmt, nil
case SEMI: case SEMI:
if lastTok != RPAREN { if lastTok != RPAREN {
return nil, p.unexpectedToken() return nil, p.unexpectedToken(RPAREN)
} }
return &stmt, nil return &stmt, nil
case FOREIGN: case FOREIGN:
if !p.expectNext(KEY) { if !p.expectSequence(KEY, LPAREN, IDENT) {
return nil, p.unexpectedToken()
}
if !p.expectNext(LPAREN) {
return nil, p.unexpectedToken()
}
if !p.expectNext(IDENT) {
return nil, p.unexpectedToken() return nil, p.unexpectedToken()
} }
columnName := p.last.lit columnName := p.last.lit
if !p.expectNext(RPAREN) { if !p.expectSequence(RPAREN, REFERENCES) {
return nil, p.unexpectedToken()
}
if !p.expectNext(REFERENCES) {
return nil, p.unexpectedToken() return nil, p.unexpectedToken()
} }
@@ -104,7 +88,7 @@ func (p *Parser) Parse() (*CreateTableStatement, error) {
stmt.Columns[column].Extra = append(stmt.Columns[column].Extra, ref) stmt.Columns[column].Extra = append(stmt.Columns[column].Extra, ref)
default: default:
return nil, p.unexpectedToken() return nil, p.unexpectedToken(IDENT, RPAREN, SEMI, FOREIGN)
} }
} }
} }
@@ -119,7 +103,7 @@ func (p *Parser) parseColumn() (Column, error) {
case INTEGER: case INTEGER:
column.Type = lit column.Type = lit
default: default:
return Column{}, p.unexpectedToken() return Column{}, p.unexpectedToken(TEXT, INTEGER)
} }
for { for {
@@ -134,7 +118,7 @@ func (p *Parser) parseColumn() (Column, error) {
fallthrough fallthrough
case NOT: case NOT:
if _, ok := p.expectOne(NULL, KEY); !ok { if _, ok := p.expectOne(NULL, KEY); !ok {
return Column{}, p.unexpectedToken() return Column{}, p.unexpectedToken(NULL, KEY)
} }
column.Extra = append(column.Extra, fmt.Sprintf("%v_%v", lit, p.last.lit)) column.Extra = append(column.Extra, fmt.Sprintf("%v_%v", lit, p.last.lit))
@@ -146,38 +130,63 @@ func (p *Parser) parseColumn() (Column, error) {
column.Extra = append(column.Extra, ref) column.Extra = append(column.Extra, ref)
default: default:
return Column{}, p.unexpectedToken() return Column{}, p.unexpectedToken(COMMA, RPAREN, PRIMARY, NOT, REFERENCES)
} }
} }
} }
func (p *Parser) unexpectedToken() error { func (p *Parser) unexpectedToken(expected ...Token) error {
return fmt.Errorf("Encountered unexpected token: %v lit: '%v' on pos: %v", p.last.tok, p.last.lit, p.last.pos) l := len(expected)
if l <= 0 {
return fmt.Errorf("Encountered unexpected token: %v lit: '%v' on pos: %v", p.last.tok, p.last.lit, p.last.pos)
} else if l == 1 {
return fmt.Errorf(
"Encountered unexpected token: %v lit: '%v' on pos: %v, expected %v",
p.last.tok,
p.last.lit,
p.last.pos,
expected[0],
)
} else {
return fmt.Errorf(
"Encountered unexpected token: %v lit: '%v' on pos: %v, expected one of '%v'",
p.last.tok,
p.last.lit,
p.last.pos,
arrayToString(expected, ", "),
)
}
} }
func (p *Parser) references() (string, error) { func (p *Parser) references() (string, error) {
if !p.expectNext(IDENT) { if !p.expectNext(IDENT) {
return "", p.unexpectedToken() return "", p.unexpectedToken(IDENT)
} }
referenceTableName := p.last.lit referenceTableName := p.last.lit
if !p.expectNext(LPAREN) { if !p.expectSequence(LPAREN, IDENT) {
return "", p.unexpectedToken()
}
if !p.expectNext(IDENT) {
return "", p.unexpectedToken() return "", p.unexpectedToken()
} }
referenceColumnName := p.last.lit referenceColumnName := p.last.lit
if !p.expectNext(RPAREN) { if !p.expectNext(RPAREN) {
return "", p.unexpectedToken() return "", p.unexpectedToken(RPAREN)
} }
return fmt.Sprintf("ref %v.%v", referenceTableName, referenceColumnName), nil return fmt.Sprintf("ref %v.%v", referenceTableName, referenceColumnName), nil
} }
func (p *Parser) expectSequence(token ...Token) bool {
for _, tok := range token {
if !p.expectNext(tok) {
return false
}
}
return true
}
func (p *Parser) expectNext(token Token) bool { func (p *Parser) expectNext(token Token) bool {
_, tok, _ := p.scan() _, tok, _ := p.scan()
return tok == token return tok == token

13
sql/util.go Normal file
View File

@@ -0,0 +1,13 @@
package sql
import (
"fmt"
"strings"
)
func arrayToString(a []Token, delim string) string {
return strings.Trim(strings.Replace(fmt.Sprint(a), " ", delim, -1), "[]")
//return strings.Trim(strings.Join(strings.Split(fmt.Sprint(a), " "), delim), "[]")
//return strings.Trim(strings.Join(strings.Fields(fmt.Sprint(a)), delim), "[]")
}