Skip to content

Commit be6192f

Browse files
authored
fixed flockmtl parsing on DuckDB UI (dais-polymtl#182)
1 parent e26d2ed commit be6192f

File tree

10 files changed

+1274
-157
lines changed

10 files changed

+1274
-157
lines changed

src/custom_parser/query/model_parser.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ void ModelParser::ParseCreateModel(Tokenizer& tokenizer, std::unique_ptr<QuerySt
116116
}
117117

118118
token = tokenizer.NextToken();
119-
if (token.type == TokenType::END_OF_FILE) {
119+
if (token.type == TokenType::END_OF_FILE || token.type == TokenType::SYMBOL || token.value == ";") {
120120
auto create_statement = std::make_unique<CreateModelStatement>();
121121
create_statement->catalog = catalog;
122122
create_statement->model_name = model_name;
@@ -143,7 +143,7 @@ void ModelParser::ParseDeleteModel(Tokenizer& tokenizer, std::unique_ptr<QuerySt
143143
auto model_name = token.value;
144144

145145
token = tokenizer.NextToken();
146-
if (token.type == TokenType::SYMBOL || token.value == ";") {
146+
if (token.type == TokenType::END_OF_FILE || token.type == TokenType::SYMBOL || token.value == ";") {
147147
auto delete_statement = std::make_unique<DeleteModelStatement>();
148148
delete_statement->model_name = model_name;
149149
statement = std::move(delete_statement);
@@ -175,7 +175,7 @@ void ModelParser::ParseUpdateModel(Tokenizer& tokenizer, std::unique_ptr<QuerySt
175175
auto catalog = value == "GLOBAL" ? "flockmtl_storage." : "";
176176

177177
token = tokenizer.NextToken();
178-
if (token.type == TokenType::SYMBOL || token.value == ";") {
178+
if (token.type == TokenType::END_OF_FILE || token.type == TokenType::SYMBOL || token.value == ";") {
179179
auto update_statement = std::make_unique<UpdateModelScopeStatement>();
180180
update_statement->model_name = model_name;
181181
update_statement->catalog = catalog;
@@ -255,7 +255,7 @@ void ModelParser::ParseUpdateModel(Tokenizer& tokenizer, std::unique_ptr<QuerySt
255255
}
256256

257257
token = tokenizer.NextToken();
258-
if (token.type == TokenType::END_OF_FILE) {
258+
if (token.type == TokenType::END_OF_FILE || token.type == TokenType::SYMBOL || token.value == ";") {
259259
auto update_statement = std::make_unique<UpdateModelStatement>();
260260
update_statement->new_model = new_model;
261261
update_statement->model_name = model_name;
@@ -277,7 +277,7 @@ void ModelParser::ParseGetModel(Tokenizer& tokenizer, std::unique_ptr<QueryState
277277
}
278278

279279
token = tokenizer.NextToken();
280-
if (token.type == TokenType::SYMBOL || token.value == ";") {
280+
if ((token.type == TokenType::END_OF_FILE || token.type == TokenType::SYMBOL || token.value == ";") && value == "MODELS") {
281281
auto get_all_statement = std::make_unique<GetAllModelStatement>();
282282
statement = std::move(get_all_statement);
283283
} else {
@@ -287,7 +287,7 @@ void ModelParser::ParseGetModel(Tokenizer& tokenizer, std::unique_ptr<QueryState
287287
auto model_name = token.value;
288288

289289
token = tokenizer.NextToken();
290-
if (token.type == TokenType::SYMBOL || token.value == ";") {
290+
if (token.type == TokenType::END_OF_FILE || token.type == TokenType::SYMBOL || token.value == ";") {
291291
auto get_statement = std::make_unique<GetModelStatement>();
292292
get_statement->model_name = model_name;
293293
statement = std::move(get_statement);

src/custom_parser/query/prompt_parser.cpp

Lines changed: 123 additions & 123 deletions
Large diffs are not rendered by default.

src/custom_parser/tokenizer.cpp

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -92,31 +92,47 @@ Token Tokenizer::ParseParenthesis() {
9292
return {TokenType::PARENTHESIS, std::string(1, ch)};
9393
}
9494

95+
// Parse a comment (starts with -- and goes to end of line)
96+
Token Tokenizer::ParseComment() {
97+
auto start = _position;
98+
// Skip initial --
99+
_position += 2;
100+
while (_position < static_cast<int>(_query.size()) && _query[_position] != '\n') {
101+
++_position;
102+
}
103+
auto value = _query.substr(start, _position - start);
104+
return {TokenType::COMMENT, value};
105+
}
106+
95107
// Get the next token from the input
96108
Token Tokenizer::GetNextToken() {
97109
SkipWhitespace();
98-
if (_position >= static_cast<int>(_query.size())) {
99-
return {TokenType::END_OF_FILE, ""};
100-
}
101-
102-
auto ch = _query[_position];
103-
if (ch == '\'') {
104-
return ParseStringLiteral();
105-
} else if (ch == '{') {
106-
return ParseJson();
107-
} else if (std::isalpha(ch)) {
108-
return ParseKeyword();
109-
} else if (ch == ';' || ch == ',') {
110-
return ParseSymbol();
111-
} else if (ch == '=') {
112-
return ParseSymbol();
113-
} else if (ch == '(' || ch == ')') {
114-
return ParseParenthesis();
115-
} else if (std::isdigit(ch)) {
116-
return ParseNumber();
117-
} else {
118-
return {TokenType::UNKNOWN, std::string(1, ch)};
110+
while (_position < static_cast<int>(_query.size())) {
111+
auto ch = _query[_position];
112+
if (ch == '-' && _query[_position + 1] == '-') {
113+
ParseComment();// Ignore comment
114+
SkipWhitespace();
115+
continue;
116+
}
117+
if (ch == '\'') {
118+
return ParseStringLiteral();
119+
} else if (ch == '{') {
120+
return ParseJson();
121+
} else if (std::isalpha(ch)) {
122+
return ParseKeyword();
123+
} else if (ch == ';' || ch == ',') {
124+
return ParseSymbol();
125+
} else if (ch == '=') {
126+
return ParseSymbol();
127+
} else if (ch == '(' || ch == ')') {
128+
return ParseParenthesis();
129+
} else if (std::isdigit(ch)) {
130+
return ParseNumber();
131+
} else {
132+
return {TokenType::UNKNOWN, std::string(1, ch)};
133+
}
119134
}
135+
return {TokenType::END_OF_FILE, ""};
120136
}
121137

122138
// Get the next token
@@ -137,6 +153,8 @@ std::string TokenTypeToString(TokenType type) {
137153
return "NUMBER";
138154
case TokenType::PARENTHESIS:
139155
return "PARENTHESIS";
156+
case TokenType::COMMENT:
157+
return "COMMENT";
140158
case TokenType::END_OF_FILE:
141159
return "END_OF_FILE";
142160
case TokenType::UNKNOWN:

src/include/flockmtl/custom_parser/tokenizer.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ enum class TokenType { KEYWORD,
1111
SYMBOL,
1212
NUMBER,
1313
PARENTHESIS,
14+
COMMENT,
1415
END_OF_FILE,
1516
UNKNOWN };
1617

@@ -33,6 +34,7 @@ class Tokenizer {
3334
Token ParseSymbol();
3435
Token ParseNumber();
3536
Token ParseParenthesis();
37+
Token ParseComment();
3638
Token ParseJson();
3739
Token GetNextToken();
3840

test/integration/src/integration/tests/model_parser/test_model_parser.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,7 @@ def test_create_model_invalid_syntax(integration_setup):
158158
def test_create_model_invalid_json_args(integration_setup):
159159
duckdb_cli_path, db_path = integration_setup
160160
# Invalid JSON format
161-
invalid_query1 = (
162-
"CREATE MODEL('test-model', 'gpt-4o', 'openai', '{invalid json}');"
163-
)
161+
invalid_query1 = "CREATE MODEL('test-model', 'gpt-4o', 'openai', '{invalid json}');"
164162
result1 = run_cli(duckdb_cli_path, db_path, invalid_query1)
165163
assert result1.returncode != 0
166164

@@ -262,3 +260,54 @@ def test_multiple_providers(integration_setup):
262260
assert "openai" in result.stdout
263261
assert "azure" in result.stdout
264262
assert "ollama" in result.stdout
263+
264+
265+
# Comment and Semicolon Tests
266+
def test_create_model_without_semicolon(integration_setup):
    """A CREATE MODEL statement with no trailing ';' must be visible to a later GET MODEL."""
    cli_path, database = integration_setup
    # Statement under test: deliberately missing the terminating semicolon.
    run_cli(cli_path, database, "CREATE MODEL('no-semicolon-model', 'gpt-4o', 'openai')")
    lookup = run_cli(cli_path, database, "GET MODEL 'no-semicolon-model';")
    assert "no-semicolon-model" in lookup.stdout
273+
274+
275+
def test_create_model_with_comment(integration_setup):
    """A CREATE MODEL statement followed by a trailing `--` comment must be visible to a later GET MODEL."""
    cli_path, database = integration_setup
    statement = "CREATE MODEL('comment-model', 'gpt-4o', 'openai'); -- This is a comment"
    run_cli(cli_path, database, statement)
    # Fetch the model back and look for its name in the CLI output.
    fetched = run_cli(cli_path, database, "GET MODEL 'comment-model';")
    assert "comment-model" in fetched.stdout
284+
285+
286+
def test_create_model_with_comment_before(integration_setup):
    """A CREATE MODEL statement preceded by a `--` comment line must be visible to a later GET MODEL."""
    cli_path, database = integration_setup
    # Two-line input: a comment line first, the statement on the next line.
    statement = (
        "-- Create a test model\n"
        "CREATE MODEL('comment-before-model', 'gpt-4o', 'openai');"
    )
    run_cli(cli_path, database, statement)
    fetched = run_cli(cli_path, database, "GET MODEL 'comment-before-model';")
    assert "comment-before-model" in fetched.stdout
294+
295+
296+
def test_delete_model_without_semicolon(integration_setup):
    """A DELETE MODEL statement with no trailing ';' must remove the model from GET MODEL output."""
    cli_path, database = integration_setup
    statements = (
        "CREATE MODEL('delete-no-semi', 'gpt-4o', 'openai');",
        # Statement under test: deliberately missing the terminating semicolon.
        "DELETE MODEL 'delete-no-semi'",
    )
    for statement in statements:
        run_cli(cli_path, database, statement)
    lookup = run_cli(cli_path, database, "GET MODEL 'delete-no-semi';")
    # An empty stdout trivially satisfies this check as well.
    assert "delete-no-semi" not in lookup.stdout
305+
306+
307+
def test_get_models_without_semicolon(integration_setup):
    """A GET MODELS statement with no trailing ';' must list a previously created model."""
    cli_path, database = integration_setup
    run_cli(cli_path, database, "CREATE MODEL('get-no-semi', 'gpt-4o', 'openai');")
    # Statement under test: deliberately missing the terminating semicolon.
    listing = run_cli(cli_path, database, "GET MODELS")
    assert "get-no-semi" in listing.stdout

test/integration/src/integration/tests/prompt_parser/test_prompt_parser.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,54 @@ def test_empty_prompt_content_error(integration_setup):
157157
invalid_query = "CREATE PROMPT('test', '');"
158158
result = run_cli(duckdb_cli_path, db_path, invalid_query)
159159
assert result.returncode != 0
160+
161+
162+
# Comment and Semicolon Tests
163+
def test_create_prompt_without_semicolon(integration_setup):
    """A CREATE PROMPT statement with no trailing ';' must be visible to a later GET PROMPT."""
    cli_path, database = integration_setup
    # Statement under test: deliberately missing the terminating semicolon.
    run_cli(cli_path, database, "CREATE PROMPT('no-semi-prompt', 'Test content')")
    lookup = run_cli(cli_path, database, "GET PROMPT 'no-semi-prompt';")
    assert "no-semi-prompt" in lookup.stdout
170+
171+
172+
def test_create_prompt_with_comment(integration_setup):
    """A CREATE PROMPT statement followed by a trailing `--` comment must be visible to a later GET PROMPT."""
    cli_path, database = integration_setup
    statement = "CREATE PROMPT('comment-prompt', 'Test content'); -- This is a comment"
    run_cli(cli_path, database, statement)
    # Fetch the prompt back and look for its name in the CLI output.
    fetched = run_cli(cli_path, database, "GET PROMPT 'comment-prompt';")
    assert "comment-prompt" in fetched.stdout
181+
182+
183+
def test_create_prompt_with_comment_before(integration_setup):
    """A CREATE PROMPT statement preceded by a `--` comment line must be visible to a later GET PROMPT."""
    cli_path, database = integration_setup
    # Two-line input: a comment line first, the statement on the next line.
    statement = (
        "-- Create a test prompt\n"
        "CREATE PROMPT('comment-before-prompt', 'Test content');"
    )
    run_cli(cli_path, database, statement)
    fetched = run_cli(cli_path, database, "GET PROMPT 'comment-before-prompt';")
    assert "comment-before-prompt" in fetched.stdout
191+
192+
193+
def test_delete_prompt_without_semicolon(integration_setup):
    """A DELETE PROMPT statement with no trailing ';' must remove the prompt from GET PROMPT output."""
    cli_path, database = integration_setup
    statements = (
        "CREATE PROMPT('delete-no-semi', 'Test content');",
        # Statement under test: deliberately missing the terminating semicolon.
        "DELETE PROMPT 'delete-no-semi'",
    )
    for statement in statements:
        run_cli(cli_path, database, statement)
    lookup = run_cli(cli_path, database, "GET PROMPT 'delete-no-semi';")
    # An empty stdout trivially satisfies this check as well.
    assert "delete-no-semi" not in lookup.stdout
202+
203+
204+
def test_get_prompts_without_semicolon(integration_setup):
    """A GET PROMPTS statement with no trailing ';' must list a previously created prompt."""
    cli_path, database = integration_setup
    run_cli(cli_path, database, "CREATE PROMPT('get-no-semi', 'Test content');")
    # Statement under test: deliberately missing the terminating semicolon.
    listing = run_cli(cli_path, database, "GET PROMPTS")
    assert "get-no-semi" in listing.stdout

0 commit comments

Comments
 (0)