@@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
345345 size_t last_sym_start = rule.size ();
346346 const char * pos = src;
347347
348- auto handle_repetitions = [&](int min_times, int max_times) {
348+ auto handle_repetitions = [&](int min_times, int max_times) {
349349
350- if (last_sym_start == rule.size ()) {
351- throw std::runtime_error (std::string (" expecting preceding item to */+/?/{ at " ) + pos);
352- }
350+ if (last_sym_start == rule.size ()) {
351+ throw std::runtime_error (std::string (" expecting preceding item to */+/?/{ at " ) + pos);
352+ }
353353
354- // apply transformation to previous symbol (last_sym_start to end) according to
355- // the following rewrite rules:
356- // S{m,n} --> S S S (m times) S'(n-m)
357- // S'(x) ::= S S'(x-1) |
358- // (... n-m definitions of these S' rules ...)
359- // S'(1) ::= S |
360- // S{m,} --> S S S (m times) S'
361- // S' ::= S S' |
362- // S* --> S{0,}
363- // --> S' ::= S S' |
364- // S+ --> S{1,}
365- // --> S S'
366- // S' ::= S S' |
367- // S? --> S{0,1}
368- // --> S'
369- // S' ::= S |
370- 371- llama_grammar_rule prev_rule (rule.begin () + last_sym_start, rule.end ());
372- if (min_times == 0 ) {
373- rule.resize (last_sym_start);
374- } else {
375- // Repeat the previous elements (min_times - 1) times
376- for (int i = 1 ; i < min_times; i++) {
377- rule.insert (rule.end (), prev_rule.begin (), prev_rule.end ());
378- }
354+ // apply transformation to previous symbol (last_sym_start to end) according to
355+ // the following rewrite rules:
356+ // S{m,n} --> S S S (m times) S'(n-m)
357+ // S'(x) ::= S S'(x-1) |
358+ // (... n-m definitions of these S' rules ...)
359+ // S'(1) ::= S |
360+ // S{m,} --> S S S (m times) S'
361+ // S' ::= S S' |
362+ // S* --> S{0,}
363+ // --> S' ::= S S' |
364+ // S+ --> S{1,}
365+ // --> S S'
366+ // S' ::= S S' |
367+ // S? --> S{0,1}
368+ // --> S'
369+ // S' ::= S |
370+ 371+ llama_grammar_rule prev_rule (rule.begin () + last_sym_start, rule.end ());
372+ if (min_times == 0 ) {
373+ rule.resize (last_sym_start);
374+ } else {
375+ // Repeat the previous elements (min_times - 1) times
376+ for (int i = 1 ; i < min_times; i++) {
377+ rule.insert (rule.end (), prev_rule.begin (), prev_rule.end ());
379378 }
379+ }
380380
381- uint32_t last_rec_rule_id = 0 ;
382- auto n_opt = max_times < 0 ? 1 : max_times - min_times;
381+ uint32_t last_rec_rule_id = 0 ;
382+ auto n_opt = max_times < 0 ? 1 : max_times - min_times;
383383
384- llama_grammar_rule rec_rule (prev_rule);
385- for (int i = 0 ; i < n_opt; i++) {
386- rec_rule.resize (prev_rule.size ());
387- uint32_t rec_rule_id = generate_symbol_id ( rule_name);
388- if (i > 0 || max_times < 0 ) {
389- rec_rule.push_back ({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
390- }
391- rec_rule.push_back ({LLAMA_GRETYPE_ALT, 0 });
392- rec_rule.push_back ({LLAMA_GRETYPE_END, 0 });
393- add_rule ( rec_rule_id, rec_rule);
394- last_rec_rule_id = rec_rule_id;
395- }
396- if (n_opt > 0 ) {
397- rule.push_back ({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
384+ llama_grammar_rule rec_rule (prev_rule);
385+ for (int i = 0 ; i < n_opt; i++) {
386+ rec_rule.resize (prev_rule.size ());
387+ uint32_t rec_rule_id = generate_symbol_id ( rule_name);
388+ if (i > 0 || max_times < 0 ) {
389+ rec_rule.push_back ({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
398390 }
399- };
391+ rec_rule.push_back ({LLAMA_GRETYPE_ALT, 0 });
392+ rec_rule.push_back ({LLAMA_GRETYPE_END, 0 });
393+ add_rule ( rec_rule_id, rec_rule);
394+ last_rec_rule_id = rec_rule_id;
395+ }
396+ if (n_opt > 0 ) {
397+ rule.push_back ({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
398+ }
399+ };
400400
401- while (*pos) {
402- if (*pos == ' "' ) { // literal string
403- pos++;
404- last_sym_start = rule.size ();
405- while (*pos != ' "' ) {
406- if (!*pos) {
407- throw std::runtime_error (" unexpected end of input" );
408- }
409- auto char_pair = parse_char (pos);
410- pos = char_pair.second ;
411- rule.push_back ({LLAMA_GRETYPE_CHAR, char_pair.first });
401+ while (*pos) {
402+ if (*pos == ' "' ) { // literal string
403+ pos++;
404+ last_sym_start = rule.size ();
405+ while (*pos != ' "' ) {
406+ if (!*pos) {
407+ throw std::runtime_error (" unexpected end of input" );
412408 }
413- pos = parse_space (pos + 1 , is_nested);
414- } else if (*pos == ' [' ) { // char range(s)
409+ auto char_pair = parse_char (pos);
410+ pos = char_pair.second ;
411+ rule.push_back ({LLAMA_GRETYPE_CHAR, char_pair.first });
412+ }
413+ pos = parse_space (pos + 1 , is_nested);
414+ } else if (*pos == ' [' ) { // char range(s)
415+ pos++;
416+ enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
417+ if (*pos == ' ^' ) {
415418 pos++;
416- enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
417- if (*pos == ' ^' ) {
418- pos++;
419- start_type = LLAMA_GRETYPE_CHAR_NOT;
419+ start_type = LLAMA_GRETYPE_CHAR_NOT;
420+ }
421+ last_sym_start = rule.size ();
422+ while (*pos != ' ]' ) {
423+ if (!*pos) {
424+ throw std::runtime_error (" unexpected end of input" );
420425 }
421- last_sym_start = rule.size ();
422- while (*pos != ' ]' ) {
423- if (!*pos) {
426+ auto char_pair = parse_char (pos);
427+ pos = char_pair.second ;
428+ enum llama_gretype type = last_sym_start < rule.size ()
429+ ? LLAMA_GRETYPE_CHAR_ALT
430+ : start_type;
431+ 432+ rule.push_back ({type, char_pair.first });
433+ if (pos[0 ] == ' -' && pos[1 ] != ' ]' ) {
434+ if (!pos[1 ]) {
424435 throw std::runtime_error (" unexpected end of input" );
425436 }
426- auto char_pair = parse_char (pos);
427- pos = char_pair.second ;
428- enum llama_gretype type = last_sym_start < rule.size ()
429- ? LLAMA_GRETYPE_CHAR_ALT
430- : start_type;
431- 432- rule.push_back ({type, char_pair.first });
433- if (pos[0 ] == ' -' && pos[1 ] != ' ]' ) {
434- if (!pos[1 ]) {
435- throw std::runtime_error (" unexpected end of input" );
436- }
437- auto endchar_pair = parse_char (pos + 1 );
438- pos = endchar_pair.second ;
439- rule.push_back ({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first });
440- }
441- }
442- pos = parse_space (pos + 1 , is_nested);
443- } else if (is_word_char (*pos)) { // rule reference
444- const char * name_end = parse_name (pos);
445- uint32_t ref_rule_id = get_symbol_id (pos, name_end - pos);
446- pos = parse_space (name_end, is_nested);
447- last_sym_start = rule.size ();
448- rule.push_back ({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
449- } else if (*pos == ' (' ) { // grouping
450- // parse nested alternates into synthesized rule
451- pos = parse_space (pos + 1 , true );
452- uint32_t sub_rule_id = generate_symbol_id (rule_name);
453- pos = parse_alternates (pos, rule_name, sub_rule_id, true );
454- last_sym_start = rule.size ();
455- // output reference to synthesized rule
456- rule.push_back ({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
457- if (*pos != ' )' ) {
458- throw std::runtime_error (std::string (" expecting ')' at " ) + pos);
437+ auto endchar_pair = parse_char (pos + 1 );
438+ pos = endchar_pair.second ;
439+ rule.push_back ({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first });
459440 }
441+ }
442+ pos = parse_space (pos + 1 , is_nested);
443+ } else if (is_word_char (*pos)) { // rule reference
444+ const char * name_end = parse_name (pos);
445+ uint32_t ref_rule_id = get_symbol_id (pos, name_end - pos);
446+ pos = parse_space (name_end, is_nested);
447+ last_sym_start = rule.size ();
448+ rule.push_back ({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
449+ } else if (*pos == ' (' ) { // grouping
450+ // parse nested alternates into synthesized rule
451+ pos = parse_space (pos + 1 , true );
452+ uint32_t sub_rule_id = generate_symbol_id (rule_name);
453+ pos = parse_alternates (pos, rule_name, sub_rule_id, true );
454+ last_sym_start = rule.size ();
455+ // output reference to synthesized rule
456+ rule.push_back ({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
457+ if (*pos != ' )' ) {
458+ throw std::runtime_error (std::string (" expecting ')' at " ) + pos);
459+ }
460+ pos = parse_space (pos + 1 , is_nested);
461+ } else if (*pos == ' .' ) { // any char
462+ last_sym_start = rule.size ();
463+ rule.push_back ({LLAMA_GRETYPE_CHAR_ANY, 0 });
464+ pos = parse_space (pos + 1 , is_nested);
465+ } else if (*pos == ' *' ) {
466+ pos = parse_space (pos + 1 , is_nested);
467+ handle_repetitions (0 , -1 );
468+ } else if (*pos == ' +' ) {
469+ pos = parse_space (pos + 1 , is_nested);
470+ handle_repetitions (1 , -1 );
471+ } else if (*pos == ' ?' ) {
472+ pos = parse_space (pos + 1 , is_nested);
473+ handle_repetitions (0 , 1 );
474+ } else if (*pos == ' {' ) {
475+ pos = parse_space (pos + 1 , is_nested);
476+ 477+ if (!is_digit_char (*pos)) {
478+ throw std::runtime_error (std::string (" expecting an int at " ) + pos);
479+ }
480+ const char * int_end = parse_int (pos);
481+ int min_times = std::stoul (std::string (pos, int_end - pos));
482+ pos = parse_space (int_end, is_nested);
483+ 484+ int max_times = -1 ;
485+ 486+ if (*pos == ' }' ) {
487+ max_times = min_times;
460488 pos = parse_space (pos + 1 , is_nested);
461- } else if (*pos == ' .' ) { // any char
462- last_sym_start = rule.size ();
463- rule.push_back ({LLAMA_GRETYPE_CHAR_ANY, 0 });
464- pos = parse_space (pos + 1 , is_nested);
465- } else if (*pos == ' *' ) {
466- pos = parse_space (pos + 1 , is_nested);
467- handle_repetitions (0 , -1 );
468- } else if (*pos == ' +' ) {
469- pos = parse_space (pos + 1 , is_nested);
470- handle_repetitions (1 , -1 );
471- } else if (*pos == ' ?' ) {
472- pos = parse_space (pos + 1 , is_nested);
473- handle_repetitions (0 , 1 );
474- } else if (*pos == ' {' ) {
489+ } else if (*pos == ' ,' ) {
475490 pos = parse_space (pos + 1 , is_nested);
476491
477- if (!is_digit_char (*pos)) {
478- throw std::runtime_error (std::string (" expecting an int at " ) + pos);
492+ if (is_digit_char (*pos)) {
493+ const char * int_end = parse_int (pos);
494+ max_times = std::stoul (std::string (pos, int_end - pos));
495+ pos = parse_space (int_end, is_nested);
479496 }
480- const char * int_end = parse_int (pos);
481- int min_times = std::stoul (std::string (pos, int_end - pos));
482- pos = parse_space (int_end, is_nested);
483- 484- int max_times = -1 ;
485- 486- if (*pos == ' }' ) {
487- max_times = min_times;
488- pos = parse_space (pos + 1 , is_nested);
489- } else if (*pos == ' ,' ) {
490- pos = parse_space (pos + 1 , is_nested);
491- 492- if (is_digit_char (*pos)) {
493- const char * int_end = parse_int (pos);
494- max_times = std::stoul (std::string (pos, int_end - pos));
495- pos = parse_space (int_end, is_nested);
496- }
497497
498- if (*pos != ' }' ) {
499- throw std::runtime_error (std::string (" expecting '}' at " ) + pos);
500- }
501- pos = parse_space (pos + 1 , is_nested);
502- } else {
503- throw std::runtime_error (std::string (" expecting ',' at " ) + pos);
498+ if (*pos != ' }' ) {
499+ throw std::runtime_error (std::string (" expecting '}' at " ) + pos);
504500 }
505- handle_repetitions (min_times, max_times );
501+ pos = parse_space (pos + 1 , is_nested );
506502 } else {
507- break ;
503+ throw std::runtime_error ( std::string ( " expecting ',' at " ) + pos) ;
508504 }
505+ handle_repetitions (min_times, max_times);
506+ } else {
507+ break ;
509508 }
510- return pos;
511509 }
510+ return pos;
511+ }
512512
513513const char * llama_grammar_parser::parse_rule (const char * src) {
514- const char * name_end = parse_name (src);
515- const char * pos = parse_space (name_end, false );
516- size_t name_len = name_end - src;
517- uint32_t rule_id = get_symbol_id (src, name_len);
518- const std::string name (src, name_len);
519- 520- if (!(pos[0 ] == ' :' && pos[1 ] == ' :' && pos[2 ] == ' =' )) {
521- throw std::runtime_error (std::string (" expecting ::= at " ) + pos);
522- }
523- pos = parse_space (pos + 3 , true );
514+ const char * name_end = parse_name (src);
515+ const char * pos = parse_space (name_end, false );
516+ size_t name_len = name_end - src;
517+ uint32_t rule_id = get_symbol_id (src, name_len);
518+ const std::string name (src, name_len);
519+ 520+ if (!(pos[0 ] == ' :' && pos[1 ] == ' :' && pos[2 ] == ' =' )) {
521+ throw std::runtime_error (std::string (" expecting ::= at " ) + pos);
522+ }
523+ pos = parse_space (pos + 3 , true );
524524
525- pos = parse_alternates (pos, name, rule_id, false );
525+ pos = parse_alternates (pos, name, rule_id, false );
526526
527- if (*pos == ' \r ' ) {
528- pos += pos[1 ] == ' \n ' ? 2 : 1 ;
529- } else if (*pos == ' \n ' ) {
530- pos++;
531- } else if (*pos) {
532- throw std::runtime_error (std::string (" expecting newline or end at " ) + pos);
533- }
534- return parse_space (pos, true );
527+ if (*pos == ' \r ' ) {
528+ pos += pos[1 ] == ' \n ' ? 2 : 1 ;
529+ } else if (*pos == ' \n ' ) {
530+ pos++;
531+ } else if (*pos) {
532+ throw std::runtime_error (std::string (" expecting newline or end at " ) + pos);
535533 }
534+ return parse_space (pos, true );
535+ }
536536
537537bool llama_grammar_parser::parse (const char * src) {
538538 try {
0 commit comments