Skip to content

Commit e33c4ca

Browse files
authored
👍 add update parser (#12)
* add test
* add parser
1 parent 9e91aa0 commit e33c4ca

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed

internal/parser/ast/ast.go

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,17 @@ type SelectStatement struct {
3030
Offset *OffsetStatement
3131
}
3232

33+
type UpdateStatement struct {
34+
Table string
35+
Set []SetStatement
36+
Where *WhereStatement
37+
}
38+
39+
type SetStatement struct {
40+
Column string
41+
Value Expression
42+
}
43+
3344
// ResultStatement node represents a returning expression in a SELECT statement.
3445
type ResultStatement struct {
3546
Expr Expression
@@ -85,6 +96,8 @@ func (s *LimitStatement) statementNode() {}
8596
func (s *OffsetStatement) statementNode() {}
8697
func (s *InsertStatement) statementNode() {}
8798
func (s *CreateDatabaseStatement) statementNode() {}
99+
func (s *UpdateStatement) statementNode() {}
100+
func (s *SetStatement) statementNode() {}
88101

89102
// IdentExpr node represents an identifier.
90103
type IdentExpr struct {

internal/parser/parser.go

Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,8 @@ func (p *Parser) parseStatement() (ast.Statement, error) {
3939
return p.parseSelectStatement()
4040
case token.INSERT:
4141
return p.parseInsertStatement()
42+
case token.UPDATE:
43+
return p.parseUpdateStatement()
4244
case token.CREATE:
4345
p.nextToken()
4446
return p.parseCreateStatement()
@@ -123,6 +125,27 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) {
123125
return &insert, nil
124126
}
125127

128+
func (p *Parser) parseUpdateStatement() (ast.Statement, error) {
129+
p.nextToken()
130+
131+
table, err := p.parseIdent()
132+
if err != nil {
133+
return nil, err
134+
}
135+
136+
set, err := p.parseSetStatement()
137+
if err != nil {
138+
return nil, err
139+
}
140+
141+
where, err := p.parseWhereStatement()
142+
if err != nil {
143+
return nil, err
144+
}
145+
146+
return &ast.UpdateStatement{Table: table.Name, Set: set, Where: where}, nil
147+
}
148+
126149
func (p *Parser) parseCreateStatement() (ast.Statement, error) {
127150
switch p.token.Type {
128151
case token.TABLE:
@@ -289,6 +312,46 @@ func (p *Parser) parseFromStatement() (*ast.FromStatement, error) {
289312
return &from, nil
290313
}
291314

315+
func (p *Parser) parseSetStatement() ([]ast.SetStatement, error) {
316+
if err := p.expect(token.SET); err != nil {
317+
return nil, err
318+
}
319+
320+
columns := make([]ast.SetStatement, 0)
321+
322+
for {
323+
column, err := p.parseIdent()
324+
if err != nil {
325+
return nil, err
326+
}
327+
328+
if err = p.expect(token.EQ); err != nil {
329+
return nil, err
330+
}
331+
332+
value, err := p.parsePrimaryExpr()
333+
if err != nil {
334+
return nil, err
335+
}
336+
337+
columns = append(columns, ast.SetStatement{
338+
Column: column.Name,
339+
Value: value,
340+
})
341+
342+
if p.peekToken.Type == token.EOF || p.peekToken.Type == token.WHERE {
343+
p.nextToken()
344+
break
345+
}
346+
347+
if err = p.expect(token.COMMA); err != nil {
348+
return nil, err
349+
}
350+
}
351+
352+
return columns, nil
353+
}
354+
292355
func (p *Parser) parseWhereStatement() (*ast.WhereStatement, error) {
293356
if p.token.Type != token.WHERE {
294357
return nil, nil

internal/parser/parser_test.go

Lines changed: 64 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -489,3 +489,67 @@ func TestParser_CreateTable(t *testing.T) {
489489
})
490490
}
491491
}
492+
493+
func TestParser_Update(t *testing.T) {
494+
t.Parallel()
495+
496+
tests := []struct {
497+
input string
498+
stmt ast.Statement
499+
}{
500+
{
501+
input: "UPDATE customers SET name = 'vlad', salary = 10*100 WHERE id = 1",
502+
stmt: &ast.UpdateStatement{
503+
Table: "customers",
504+
Set: []ast.SetStatement{
505+
{
506+
Column: "name",
507+
Value: &ast.ScalarExpr{
508+
Type: token.TEXT,
509+
Literal: "vlad",
510+
},
511+
},
512+
{
513+
Column: "salary",
514+
Value: &ast.ConditionExpr{
515+
Left: &ast.ScalarExpr{
516+
Type: token.INT,
517+
Literal: "10",
518+
},
519+
Operator: token.ASTERISK,
520+
Right: &ast.ScalarExpr{
521+
Type: token.INT,
522+
Literal: "100",
523+
},
524+
},
525+
},
526+
},
527+
Where: &ast.WhereStatement{
528+
Expr: &ast.ConditionExpr{
529+
Left: &ast.IdentExpr{
530+
Name: "id",
531+
},
532+
Operator: token.EQ,
533+
Right: &ast.ScalarExpr{
534+
Type: token.INT,
535+
Literal: "1",
536+
},
537+
},
538+
},
539+
},
540+
},
541+
}
542+
543+
for _, test := range tests {
544+
test := test
545+
546+
t.Run(test.input, func(t *testing.T) {
547+
t.Parallel()
548+
549+
p := New(lexer.New(test.input))
550+
stmts, err := p.Parse()
551+
assert.NoError(t, err)
552+
assert.Equal(t, test.stmt, stmts)
553+
})
554+
}
555+
}

0 commit comments

Comments
 (0)