From cd5f72bfbee9edaefa9da363ed0f8b8c94b65f45 Mon Sep 17 00:00:00 2001 From: Pablu Date: Tue, 28 Oct 2025 12:02:25 +0100 Subject: [PATCH] cleanup and error printing --- cmd/sqv/main.go | 11 +++---- sql/lexer.go | 2 ++ sql/parser.go | 87 +++++++++++++++++++++++++++---------------------- sql/util.go | 13 ++++++++ 4 files changed, 68 insertions(+), 45 deletions(-) create mode 100644 sql/util.go diff --git a/cmd/sqv/main.go b/cmd/sqv/main.go index f881c61..bb88be7 100644 --- a/cmd/sqv/main.go +++ b/cmd/sqv/main.go @@ -11,12 +11,11 @@ import ( ) func main() { - s := ` - CREATE TABLE TEST( - ID text PRIMARY KEY - ); + s := `CREATE TABLE TEST( + ID text PRIMARY KEY +); - CREATE TABLE sessions ( +CREATE TABLE sessions ( session_id text PRIMARY KEY, access_token text NOT NULL, user_email text NOT NULL, @@ -62,7 +61,7 @@ CREATE TABLE IF NOT EXISTS auth_states ( code_verifier text NOT NULL ); ` - + parser := sql.NewParser(strings.NewReader(s)) for { diff --git a/sql/lexer.go b/sql/lexer.go index f1d9115..e694343 100644 --- a/sql/lexer.go +++ b/sql/lexer.go @@ -28,6 +28,8 @@ const ( REFERENCES KEY NOT + IF + EXISTS TEXT INTEGER diff --git a/sql/parser.go b/sql/parser.go index af157fd..1d80de0 100644 --- a/sql/parser.go +++ b/sql/parser.go @@ -1,7 +1,6 @@ package sql import ( - "errors" "fmt" "io" "slices" @@ -23,17 +22,13 @@ func NewParser(r io.Reader) *Parser { func (p *Parser) Parse() (*CreateTableStatement, error) { tok, ok := p.expectOne(CREATE, EOF) if !ok { - return nil, p.unexpectedToken() + return nil, p.unexpectedToken(CREATE, EOF) } else if tok == EOF { return nil, io.EOF } - if !p.expectNext(TABLE) { - return nil, errors.New("Expect TABLE token") - } - - if !p.expectNext(IDENT) { - return nil, errors.New("Expect IDENT token") + if !p.expectSequence(TABLE, IDENT) { + return nil, p.unexpectedToken() } stmt := CreateTableStatement{ @@ -41,9 +36,9 @@ func (p *Parser) Parse() (*CreateTableStatement, error) { Columns: make([]Column, 0), } - _, tok, _ = p.scan() - if tok != LPAREN { - return nil, errors.New("Expect LPAREN token") + + if !p.expectNext(LPAREN) { + return nil, p.unexpectedToken(LPAREN) } for { @@ -60,35 +55,24 @@ func (p *Parser) Parse() (*CreateTableStatement, error) { case RPAREN: if !p.expectNext(SEMI) { - return nil, p.unexpectedToken() + return nil, p.unexpectedToken(SEMI) } return &stmt, nil case SEMI: if lastTok != RPAREN { - return nil, p.unexpectedToken() + return nil, p.unexpectedToken(RPAREN) } return &stmt, nil case FOREIGN: - if !p.expectNext(KEY) { - return nil, p.unexpectedToken() - } - if !p.expectNext(LPAREN) { - return nil, p.unexpectedToken() - } - - if !p.expectNext(IDENT) { + if !p.expectSequence(KEY, LPAREN, IDENT) { return nil, p.unexpectedToken() } columnName := p.last.lit - if !p.expectNext(RPAREN) { - return nil, p.unexpectedToken() - } - - if !p.expectNext(REFERENCES) { + if !p.expectSequence(RPAREN, REFERENCES) { return nil, p.unexpectedToken() } @@ -104,7 +88,7 @@ func (p *Parser) Parse() (*CreateTableStatement, error) { stmt.Columns[column].Extra = append(stmt.Columns[column].Extra, ref) default: - return nil, p.unexpectedToken() + return nil, p.unexpectedToken(IDENT, RPAREN, SEMI, FOREIGN) } } } @@ -119,7 +103,7 @@ func (p *Parser) parseColumn() (Column, error) { case INTEGER: column.Type = lit default: - return Column{}, p.unexpectedToken() + return Column{}, p.unexpectedToken(TEXT, INTEGER) } for { @@ -134,7 +118,7 @@ func (p *Parser) parseColumn() (Column, error) { fallthrough case NOT: if _, ok := p.expectOne(NULL, KEY); !ok { - return Column{}, p.unexpectedToken() + return Column{}, p.unexpectedToken(NULL, KEY) } column.Extra = append(column.Extra, fmt.Sprintf("%v_%v", lit, p.last.lit)) @@ -146,38 +130,63 @@ func (p *Parser) parseColumn() (Column, error) { column.Extra = append(column.Extra, ref) default: - return Column{}, p.unexpectedToken() + return Column{}, p.unexpectedToken(COMMA, RPAREN, PRIMARY, NOT, REFERENCES) } } } -func (p *Parser) unexpectedToken() error { - return fmt.Errorf("Encountered unexpected token: %v lit: '%v' on pos: %v", p.last.tok, p.last.lit, p.last.pos) +func (p *Parser) unexpectedToken(expected ...Token) error { + l := len(expected) + if l <= 0 { + return fmt.Errorf("Encountered unexpected token: %v lit: '%v' on pos: %v", p.last.tok, p.last.lit, p.last.pos) + } else if l == 1 { + return fmt.Errorf( + "Encountered unexpected token: %v lit: '%v' on pos: %v, expected %v", + p.last.tok, + p.last.lit, + p.last.pos, + expected[0], + ) + } else { + return fmt.Errorf( + "Encountered unexpected token: %v lit: '%v' on pos: %v, expected one of '%v'", + p.last.tok, + p.last.lit, + p.last.pos, + arrayToString(expected, ", "), + ) + } } func (p *Parser) references() (string, error) { if !p.expectNext(IDENT) { - return "", p.unexpectedToken() + return "", p.unexpectedToken(IDENT) } referenceTableName := p.last.lit - if !p.expectNext(LPAREN) { - return "", p.unexpectedToken() - } - - if !p.expectNext(IDENT) { + if !p.expectSequence(LPAREN, IDENT) { return "", p.unexpectedToken() } referenceColumnName := p.last.lit if !p.expectNext(RPAREN) { - return "", p.unexpectedToken() + return "", p.unexpectedToken(RPAREN) } return fmt.Sprintf("ref %v.%v", referenceTableName, referenceColumnName), nil } +func (p *Parser) expectSequence(token ...Token) bool { + for _, tok := range token { + if !p.expectNext(tok) { + return false + } + } + + return true +} + func (p *Parser) expectNext(token Token) bool { _, tok, _ := p.scan() return tok == token diff --git a/sql/util.go b/sql/util.go new file mode 100644 index 0000000..82c5b0f --- /dev/null +++ b/sql/util.go @@ -0,0 +1,13 @@ +package sql + +import ( + "fmt" + "strings" +) + + +func arrayToString(a []Token, delim string) string { + return strings.Trim(strings.Replace(fmt.Sprint(a), " ", delim, -1), "[]") + //return strings.Trim(strings.Join(strings.Split(fmt.Sprint(a), " "), delim), "[]") + //return strings.Trim(strings.Join(strings.Fields(fmt.Sprint(a)), delim), "[]") +}