High severity · NVD Advisory · Published Aug 25, 2025 · Updated Aug 26, 2025

XGrammar affected by Denial of Service by infinite recursion grammars

CVE-2025-57809

Description

XGrammar is an open-source library for efficient, flexible, and portable structured generation. Prior to version 0.1.21, XGrammar has an infinite recursion issue in the grammar. This issue has been resolved in version 0.1.21.

AI Insight

LLM-synthesized narrative grounded in this CVE's description and references.

XGrammar before v0.1.21 has an infinite recursion bug in its grammar processing, leading to a denial of service via resource exhaustion.

Vulnerability Overview

XGrammar, an open-source library for structured generation, contains an infinite recursion vulnerability in its grammar processing logic in versions prior to 0.1.21. According to the security advisory [2], malformed or specially crafted grammar definitions can drive the grammar parser into unbounded recursion, consuming CPU and memory until the process crashes or becomes unresponsive.
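The advisory does not spell out the exact code path, so the following is only a toy illustration of the general failure mode, not XGrammar's implementation. It assumes a hypothetical `ToyGrammar` type and a hypothetical `FirstTerminals` helper that walks leading rule references. Without the `visited` guard, the self-referential rule `expr ::= expr "+" term` would make the walk recurse until the stack is exhausted; with the guard, each rule is expanded at most once. Empty alternatives are simply skipped to keep the sketch short.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical toy grammar: rule name -> list of alternatives.
// A symbol that is a key of the map is a rule reference; anything else is a terminal.
using Alternative = std::vector<std::string>;
using ToyGrammar = std::unordered_map<std::string, std::vector<Alternative>>;

// Collect the terminals that can appear first when expanding `rule`.
void CollectFirstTerminals(
    const ToyGrammar& grammar,
    const std::string& rule,
    std::set<std::string>& visited,
    std::set<std::string>& out
) {
  // Guard: a rule that was already entered is not expanded again.
  // Removing this check makes a left-recursive rule recurse forever.
  if (!visited.insert(rule).second) return;
  auto it = grammar.find(rule);
  if (it == grammar.end()) return;
  for (const Alternative& alt : it->second) {
    if (alt.empty()) continue;  // empty alternatives are ignored in this sketch
    const std::string& head = alt.front();
    if (grammar.count(head) > 0) {
      CollectFirstTerminals(grammar, head, visited, out);  // leading rule reference
    } else {
      out.insert(head);  // leading terminal
    }
  }
}

std::set<std::string> FirstTerminals(const ToyGrammar& grammar, const std::string& root) {
  std::set<std::string> visited;
  std::set<std::string> out;
  CollectFirstTerminals(grammar, root, visited, out);
  return out;
}
```

Calling `FirstTerminals(grammar, "expr")` on a grammar whose `expr` rule refers to itself returns normally because the second visit to `expr` is cut off by the guard.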

Exploitation Vector

An attacker can exploit this vulnerability by supplying a crafted grammar to an XGrammar-powered application. No authentication is required; the exploit can be triggered by submitting a malicious grammar to a public-facing service that uses XGrammar to validate or generate structured output (e.g., a JSON schema validation endpoint). Public references [2] and [3] indicate that complex, deeply nested grammars (as seen in a reported segmentation fault) can also trigger resource exhaustion.

Impact

The primary impact is a denial of service (DoS). An attacker can cause the affected application to hang, crash, or become unavailable to legitimate users. Since XGrammar is integrated into major LLM inference engines like vLLM, SGLang, TensorRT-LLM, and MLC-LLM [4], a successful exploit could disrupt production systems relying on these frameworks for structured generation.

Mitigation

The vulnerability is fixed in XGrammar version 0.1.21 [1]. Users should immediately update to this patched release. There is no known workaround for unpatched versions. The issue has been publicly reported and is addressed in the official advisory [2].
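As a rough aid for auditing deployments, the affected range `< 0.1.21` can be checked with a plain numeric comparison. The sketch below is an assumption-laden helper, not part of XGrammar: `ParseVersion` and `IsAffectedXGrammarVersion` are hypothetical names, and only simple `MAJOR.MINOR.PATCH` strings are handled (any non-numeric suffix after the leading digits of a field is ignored).

```cpp
#include <array>
#include <cassert>
#include <exception>
#include <sstream>
#include <string>

// Parse "MAJOR.MINOR.PATCH" into three integers.
// Missing or non-numeric fields are treated as 0 (a simplification).
std::array<int, 3> ParseVersion(const std::string& text) {
  std::array<int, 3> parts{0, 0, 0};
  std::istringstream in(text);
  std::string field;
  for (int i = 0; i < 3 && std::getline(in, field, '.'); ++i) {
    try {
      parts[i] = std::stoi(field);  // reads the leading digits of the field
    } catch (const std::exception&) {
      parts[i] = 0;
    }
  }
  return parts;
}

// True when the version falls in the advisory's affected range (< 0.1.21).
bool IsAffectedXGrammarVersion(const std::string& version) {
  const std::array<int, 3> first_patched{0, 1, 21};
  return ParseVersion(version) < first_patched;  // lexicographic comparison
}
```

Comparing the fields as integers matters here: a string comparison would wrongly rank `0.1.9` above `0.1.21`.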

AI Insight generated on May 19, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.

Affected packages

Versions sourced from the GitHub Security Advisory.

Package          Affected versions   Patched versions
xgrammar (PyPI)  < 0.1.21            0.1.21

Affected products

2
  • Mlc Ai/Xgrammar (llm-fuzzy), 2 versions
    • (no CPE) range: <0.1.21
    • (no CPE) range: < 0.1.21

Patches

1
b943feacb5a1

[Feature] Use an earley parser to replace the current pushdown automata (#308)

https://github.com/mlc-ai/xgrammar · Linzhang Li · Jul 2, 2025 · via ghsa
18 files changed · +1844 -261
  • cmake/config.cmake +1 -1 modified
    @@ -1,5 +1,5 @@
     set(CMAKE_BUILD_TYPE RelWithDebInfo)
     set(XGRAMMAR_BUILD_PYTHON_BINDINGS ON)
     set(XGRAMMAR_ENABLE_COVERAGE OFF)
    -set(XGRAMMAR_BUILD_CXX_TESTS OFF)
    +set(XGRAMMAR_BUILD_CXX_TESTS ON)
     set(XGRAMMAR_ENABLE_CPPTRACE OFF)
    
  • cpp/compiled_grammar_data_structure.h +12 -4 modified
    @@ -15,6 +15,8 @@
     #include <utility>
     #include <vector>
     
    +#include "earley_parser.h"
    +
     // matcher_data_structure.h is included to use StackElement
     #include "persistent_stack.h"
     #include "support/dynamic_bitset.h"
    @@ -43,7 +45,7 @@ struct AdaptiveTokenMask {
         // Only store all accepted token indices. Then rejected indices = all_indices - accepted_indices
         // - uncertain_indices. This is useful when |accepted_indices| < |rejected_indices|.
         kAccepted = 0,
    -    // Only store all accepted token indices. Then accepted indices = all_indices - rejected_indices
    +    // Only store all rejected token indices. Then accepted indices = all_indices - rejected_indices
         // - uncertain_indices. This is useful when |accepted_indices| > |rejected_indices|.
         kRejected = 1,
         // Store all accepted token indices in a bitset. This is useful when both |accepted_indices| and
    @@ -70,6 +72,13 @@ struct AdaptiveTokenMask {
           const std::vector<int32_t>& uncertain_indices
       );
     
    +  AdaptiveTokenMask(
    +      size_t vocab_size,
    +      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
    +      const std::vector<int32_t>& accepted_indices,
    +      const std::vector<int32_t>& uncertain_indices
    +  );
    +
       std::string Print(const TokenizerInfo& tokenizer_info) const;
     
       std::size_t MemorySize() const;
    @@ -121,9 +130,8 @@ class CompiledGrammar::Impl {
         }
       };
     
    -  /*! \brief Mapping from the stack top element to the adaptive token mask. */
    -  std::unordered_map<StackElement, AdaptiveTokenMask, StackElementHash, StackElementEqual>
    -      adaptive_token_mask_cache;
    +  /*! \brief Mapping from the parser state to the adaptive token mask. */
    +  std::unordered_map<ParserState, AdaptiveTokenMask, StateHashForCache> adaptive_token_mask_cache;
     
       Grammar GetGrammar() const { return grammar; }
     
    
  • cpp/earley_parser.cc +604 -0 added
    @@ -0,0 +1,604 @@
    +/*!
    + *  Copyright (c) 2025 by Contributors
    + * \file xgrammar/earley_parser.cc
    + */
    +
    +#include "earley_parser.h"
    +
    +#include <cassert>
    +#include <cctype>
    +#include <cstdint>
    +#include <ctime>
    +#include <unordered_map>
    +#include <utility>
    +#include <vector>
    +
    +#include "fsm.h"
    +#include "grammar_data_structure.h"
    +#include "support/encoding.h"
    +#include "support/logging.h"
    +#include "xgrammar/grammar.h"
    +
    +namespace xgrammar {
    +
    +using RuleExprType = Grammar::Impl::RuleExprType;
    +
    +using RuleExpr = Grammar::Impl::RuleExpr;
    +
    +bool EarleyParser::IsCompleted() const { return is_completed_.back(); }
    +
    +void EarleyParser::PopLastStates(int32_t cnt) {
    +  if (stop_token_is_accepted_) {
    +    stop_token_is_accepted_ = false;
    +  }
    +  if (cnt >= static_cast<int32_t>(rule_id_to_completeable_states_.size())) {
    +    XGRAMMAR_LOG(FATAL) << "The number of states to be popped is larger than the size of states.";
    +  }
    +  rule_id_to_completeable_states_.erase(
    +      rule_id_to_completeable_states_.end() - cnt, rule_id_to_completeable_states_.end()
    +  );
    +  is_completed_.erase(is_completed_.end() - cnt, is_completed_.end());
    +  scanable_state_history_.PopBack(cnt);
    +}
    +
    +void EarleyParser::Complete(const ParserState& state, const RuleExpr& rule_expr) {
    +  // Check if a rule is completed.
    +  if (state.rule_start_pos == ParserState::kNoPrevInputPos) {
    +    // assert: if a root rule can achieve here, then it must be completed.
    +    XGRAMMAR_DCHECK(rule_expr.type == RuleExprType::kSequence);
    +    XGRAMMAR_DCHECK(rule_expr.size() == state.element_id);
    +    tmp_accept_stop_token_ = true;
    +    return;
    +  }
    +  // Check all the possible parent states.
    +  const auto& parent_states_map = rule_id_to_completeable_states_[state.rule_start_pos];
    +  for (auto parent_state_iter = parent_states_map.lower_bound(state.rule_id);
    +       parent_state_iter != parent_states_map.end() && parent_state_iter->first == state.rule_id;
    +       parent_state_iter++) {
    +    const auto& parent_state = parent_state_iter->second;
    +    const auto& parent_expr = grammar_->GetRuleExpr(parent_state.sequence_id);
    +    if (parent_expr.type == RuleExprType::kSequence) {
    +      // These two types can predict other new rules. We need to
    +      // to move to the next element.
    +      XGRAMMAR_DCHECK(
    +          grammar_->GetRuleExpr(parent_expr[parent_state.element_id]).type == RuleExprType::kRuleRef
    +      );
    +      Enqueue(ParserState{
    +          parent_state.rule_id,
    +          parent_state.sequence_id,
    +          parent_state.element_id + 1,
    +          parent_state.rule_start_pos,
    +          0
    +      });
    +      continue;
    +    }
    +    XGRAMMAR_DCHECK(parent_expr.type == RuleExprType::kTagDispatch);
    +    Enqueue(
    +        {parent_state.rule_id,
    +         parent_state.sequence_id,
    +         grammar_->root_tag_dispatch_fsm->GetStart(),
    +         parent_state.rule_start_pos,
    +         0}
    +    );
    +  }
    +}
    +
    +std::pair</* scanable */ bool, /* completable */ bool> EarleyParser::Predict(
    +    const ParserState& state, const RuleExpr& rule_expr
    +) {
    +  //  If the current state is the end of the rule, we do not need to predict.
    +  if (rule_expr.type == RuleExprType::kTagDispatch) {
    +    // The rule can be scanned, but can't be completed.
    +    if (!grammar_->root_tag_dispatch_fsm->IsEndState(state.element_id)) {
    +      tmp_accept_stop_token_ = true;
    +      return std::make_pair(true, false);
    +    }
    +    // A tag has is dispatched.
    +    ExpandNextRuleRefElement(state, rule_expr, nullptr);
    +    return std::make_pair(false, false);
    +  }
    +  XGRAMMAR_DCHECK(rule_expr.type == RuleExprType::kSequence);
    +  if (state.element_id == rule_expr.size()) {
    +    // The rule is completed.
    +    return std::make_pair(false, true);
    +  }
    +  const auto& element_expr = grammar_->GetRuleExpr(rule_expr[state.element_id]);
    +  if (element_expr.type == RuleExprType::kRuleRef) {
    +    ExpandNextRuleRefElement(state, rule_expr, &element_expr);
    +    return std::make_pair(false, false);
    +  }
    +  if (element_expr.type == RuleExprType::kCharacterClassStar && state.sub_element_id == 0) {
    +    Enqueue(
    +        ParserState{state.rule_id, state.sequence_id, state.element_id + 1, state.rule_start_pos, 0}
    +    );
    +  }
    +  return std::make_pair(true, false);
    +}
    +
    +void EarleyParser::Scan(const ParserState& state, const uint8_t ch) {
    +  const auto& cur_rule = grammar_->GetRuleExpr(state.sequence_id);
    +  XGRAMMAR_DCHECK(
    +      state.element_id != cur_rule.size() || cur_rule.type == RuleExprType::kTagDispatch
    +  );
    +  if (cur_rule.type == RuleExprType::kSequence) {
    +    const auto& element_expr = grammar_->GetRuleExpr(cur_rule[state.element_id]);
    +    // The element is a rule reference, we do not need to scan it.
    +    switch (element_expr.type) {
    +      case (RuleExprType::kByteString): {
    +        AdvanceByteString(state, ch, element_expr);
    +        break;
    +      }
    +      case (RuleExprType::kCharacterClass): {
    +        AdvanceCharacterClass(state, ch, element_expr);
    +        break;
    +      }
    +      case (RuleExprType::kCharacterClassStar): {
    +        AdvanceCharacterClassStar(state, ch, element_expr);
    +        break;
    +      }
    +      default: {
    +        XGRAMMAR_LOG(FATAL) << "The element type is not supported! The type is: "
    +                            << int(element_expr.type);
    +      }
    +    }
    +  } else {
    +    XGRAMMAR_DCHECK(cur_rule.type == RuleExprType::kTagDispatch);
    +    AdvanceTagDispatch(state, ch, cur_rule);
    +  }
    +}
    +
    +/*!
    +  \note The workflow of Advance is as follows:
    +  1. Scan all the states in the latest states. Add all the possible states
    +  to the next states.
    +  2. If the next states are empty, then the character is not accepted.
    +  3. If the next states are not empty, then the character is accepted. Moreover,
    +  we need to complete and predict the next states.
    +
    +  \note Thus, when initializing the Earley parser, we need to add the initial state
    +  to the history_states[0], and perform prediction and completion on the initial state.
    +*/
    +bool EarleyParser::Advance(const uint8_t ch) {
    +  // Initialize the containers.
    +  XGRAMMAR_DCHECK(tmp_process_state_queue_.empty())
    +      << "The tmp_process_state_queue_ should be empty before the scan.";
    +  tmp_states_visited_in_queue_.Clear();
    +  tmp_states_to_be_added_.clear();
    +  tmp_accept_stop_token_ = false;
    +  const auto& latest_states = scanable_state_history_[scanable_state_history_.size() - 1];
    +
    +  // Scan all the scanable states.
    +  for (const auto& state : latest_states) {
    +    Scan(state, ch);
    +  }
    +
    +  // Check if the character is accepted.
    +  if (tmp_process_state_queue_.empty() && tmp_states_to_be_added_.empty()) {
    +    return false;
    +  }
    +
    +  // execute Predict and Complete for all states in the queue until empty.
    +  rule_id_to_completeable_states_.emplace_back();
    +  while (!tmp_process_state_queue_.empty()) {
    +    const auto state = tmp_process_state_queue_.front();
    +    tmp_process_state_queue_.pop();
    +    RuleExpr rule_expr = grammar_->GetRuleExpr(state.sequence_id);
    +    auto [scanable, completable] = Predict(state, rule_expr);
    +    if (completable) {
    +      Complete(state, rule_expr);
    +    } else if (scanable) {  // A completable state can be scanned.
    +      tmp_states_to_be_added_.push_back(state);
    +    }
    +  }
    +
    +  // Check if the grammar is completed, and add the scannable states to the history.
    +  is_completed_.push_back(tmp_accept_stop_token_);
    +  scanable_state_history_.PushBack(tmp_states_to_be_added_);
    +  return true;
    +}
    +
    +EarleyParser::EarleyParser(
    +    const Grammar& grammar, const ParserState& init_state, const bool need_expand
    +)
    +    : grammar_(grammar) {
    +  // Check if the initial state is valid. If invalid, then we choose the root state as default.
    +  ParserState init = init_state;
    +  if (init_state.IsInvalid()) {
    +    init = ParserState(
    +        grammar_->GetRootRuleId(),
    +        ParserState::kUnexpandedRuleStartSequenceId,
    +        0,
    +        ParserState::kNoPrevInputPos,
    +        0
    +    );
    +  } else {
    +    init = init_state;
    +  }
    +
    +  // If there is no need to expand the initial state, we only need to add it to the
    +  // scanable states history.
    +  if (!need_expand) {
    +    rule_id_to_completeable_states_.emplace_back();
    +    is_completed_.push_back(false);
    +    scanable_state_history_.PushBack({init});
    +  }
    +
    +  // Otherwise, we expand the initial state, and process the queue.
    +  PushStateAndExpand(init);
    +}
    +
    +void EarleyParser::PushStateAndExpand(const ParserState& state) {
    +  tmp_states_visited_in_queue_.Clear();
    +  tmp_accept_stop_token_ = false;
    +  tmp_states_to_be_added_.clear();
    +  rule_id_to_completeable_states_.emplace_back();
    +  if (state.IsInvalid()) {
    +    ExpandAndEnqueueUnexpandedState(ParserState{
    +        grammar_->GetRootRuleId(),
    +        ParserState::kUnexpandedRuleStartSequenceId,
    +        0,
    +        ParserState::kNoPrevInputPos,
    +        0
    +    });
    +  } else {
    +    // If the rule can't be expanded, we need to add it to the queue.
    +    if (!ExpandAndEnqueueUnexpandedState(state)) {
    +      Enqueue(state);
    +    }
    +  }
    +  while (!tmp_process_state_queue_.empty()) {
    +    const auto state = tmp_process_state_queue_.front();
    +    tmp_process_state_queue_.pop();
    +    RuleExpr rule_expr = grammar_->GetRuleExpr(state.sequence_id);
    +    auto [scanable, completable] = Predict(state, rule_expr);
    +    if (completable) {
    +      Complete(state, rule_expr);
    +    }
    +    if (scanable) {
    +      tmp_states_to_be_added_.push_back(state);
    +    }
    +  }
    +  is_completed_.push_back(tmp_accept_stop_token_);
    +  scanable_state_history_.PushBack(tmp_states_to_be_added_);
    +}
    +
    +void EarleyParser::Reset() {
    +  rule_id_to_completeable_states_.clear();
    +  scanable_state_history_.PopBack(scanable_state_history_.size());
    +  is_completed_.clear();
    +  stop_token_is_accepted_ = false;
    +  XGRAMMAR_DCHECK(tmp_process_state_queue_.empty());
    +  PushStateAndExpand(ParserState(
    +      grammar_->GetRootRuleId(),
    +      ParserState::kUnexpandedRuleStartSequenceId,
    +      0,
    +      ParserState::kNoPrevInputPos,
    +      0
    +  ));
    +}
    +
    +bool EarleyParser::ExpandAndEnqueueUnexpandedState(const ParserState& state) {
    +  if (state.sequence_id != ParserState::kUnexpandedRuleStartSequenceId) {
    +    return false;
    +  }
    +  // The rule is already expanded, and finished.
    +  auto cur_rule_id = state.rule_id;
    +  auto cur_rule_body_id = grammar_->GetRule(cur_rule_id).body_expr_id;
    +  auto cur_rule_body = grammar_->GetRuleExpr(cur_rule_body_id);
    +  // There are two types of an unexpanded rule:
    +  // 1. The rule is a tag dispatch rule.
    +  // 2. The rule is a choice, consisting of multiple sequences.
    +  if (cur_rule_body.type == RuleExprType::kTagDispatch) {
    +    Enqueue(ParserState{
    +        cur_rule_id,
    +        cur_rule_body_id,
    +        grammar_->root_tag_dispatch_fsm->GetStart(),
    +        ParserState::kNoPrevInputPos,
    +        0
    +    });
    +    return true;
    +  }
    +  XGRAMMAR_DCHECK(cur_rule_body.type == RuleExprType::kChoices);
    +  for (const auto& sequence_id : cur_rule_body) {
    +    Enqueue(ParserState{cur_rule_id, sequence_id, 0, ParserState::kNoPrevInputPos, 0});
    +  }
    +  return true;
    +}
    +
    +void EarleyParser::ExpandNextRuleRefElement(
    +    const ParserState& state, const RuleExpr& rule_expr, const RuleExpr* sub_rule_expr
    +) {
    +  // Get the reference rule id.
    +  int ref_rule_id;
    +  if (rule_expr.type == RuleExprType::kTagDispatch) {
    +    XGRAMMAR_DCHECK(grammar_->root_tag_dispatch_fsm->IsEndState(state.element_id));
    +    ref_rule_id = grammar_->tag_dispatch_end_node_to_rule_id[state.element_id];
    +  } else {
    +    XGRAMMAR_DCHECK(rule_expr.type == RuleExprType::kSequence);
    +    XGRAMMAR_DCHECK(sub_rule_expr->type == RuleExprType::kRuleRef);
    +    ref_rule_id = (*sub_rule_expr)[0];
    +  }
    +
    +  // Add the reference rule to map.
    +  if ((state.element_id != rule_expr.size() - 1) ||
    +      state.rule_start_pos == ParserState::kNoPrevInputPos) {
    +    // It's not the right recursion, or it's the root rule.
    +    auto& states_map = rule_id_to_completeable_states_.back();
    +    states_map.insert({ref_rule_id, state});
    +  } else {
    +    // If it's the right recursion, we need to add the ancestors of the parent state.
    +    auto& states_map = rule_id_to_completeable_states_.back();
    +    auto& parent_states_map = rule_id_to_completeable_states_[state.rule_start_pos];
    +    const auto& range = states_map.equal_range(ref_rule_id);
    +    const auto in_vec = [&](const ParserState& state_) {
    +      return std::find_if(range.first, range.second, [&](const auto& s) {
    +               return StateEqualForParsing()(s.second, state_);
    +             }) != range.second;
    +    };
    +    for (auto parent_state_iter = parent_states_map.lower_bound(state.rule_id);
    +         parent_state_iter != parent_states_map.end() && parent_state_iter->first == state.rule_id;
    +         parent_state_iter++) {
    +      const auto& parent_state = parent_state_iter->second;
    +      if (!in_vec(parent_state)) {
    +        states_map.insert({ref_rule_id, parent_state});
    +      }
    +    }
    +  }
    +
    +  // Check if the reference rule is already visited.
    +  if (IsStateVisitedInQueue({ref_rule_id, -1, -1, -1, -1})) {
    +    if (std::find(
    +            grammar_->allow_empty_rule_ids.begin(),
    +            grammar_->allow_empty_rule_ids.end(),
    +            ref_rule_id
    +        ) != grammar_->allow_empty_rule_ids.end()) {
    +      if (rule_expr.type == RuleExprType::kTagDispatch) {
    +        Enqueue(ParserState{
    +            state.rule_id,
    +            state.sequence_id,
    +            grammar_->root_tag_dispatch_fsm->GetStart(),
    +            state.rule_start_pos,
    +            0
    +        });
    +        tmp_accept_stop_token_ = true;
    +        return;
    +      }
    +      XGRAMMAR_DCHECK(rule_expr.type == RuleExprType::kSequence);
    +      Enqueue(ParserState{
    +          state.rule_id, state.sequence_id, state.element_id + 1, state.rule_start_pos, 0
    +      });
    +    }
    +    return;
    +  }
    +
    +  // If the reference rule is not visited, we need to add it to the queue.
    +  tmp_states_visited_in_queue_.Insert({ref_rule_id, -1, -1, -1, -1});
    +  const auto& ref_rule = grammar_->GetRule(ref_rule_id);
    +  const auto& ref_rule_expr_id = ref_rule.body_expr_id;
    +  const auto& ref_rule_expr = grammar_->GetRuleExpr(ref_rule_expr_id);
    +  XGRAMMAR_DCHECK(ref_rule_expr.type == RuleExprType::kChoices);
    +  for (const auto& sequence_id : ref_rule_expr) {
    +    const auto& sequence = grammar_->GetRuleExpr(sequence_id);
    +    if (sequence.type == RuleExprType::kEmptyStr) {
    +      Enqueue(ParserState{
    +          state.rule_id, state.sequence_id, state.element_id + 1, state.rule_start_pos, 0
    +      });
    +      continue;
    +    }
    +    // Assert: the state can't be repeated. Since the rule_start_pos is the current
    +    // position, and the rule can only be predicted once.
    +    tmp_process_state_queue_.push(ParserState{
    +        ref_rule_id, sequence_id, 0, int32_t(rule_id_to_completeable_states_.size()) - 1, 0
    +    });
    +  }
    +}
    +
    +void EarleyParser::AdvanceByteString(
    +    const ParserState& state, const uint8_t ch, const RuleExpr& sub_rule
    +) {
    +  XGRAMMAR_DCHECK(sub_rule.type == RuleExprType::kByteString);
    +  XGRAMMAR_DCHECK(sub_rule.size() > state.sub_element_id);
    +  if (static_cast<uint8_t>(sub_rule[state.sub_element_id]) == ch) {
    +    auto new_state = state;
    +    new_state.sub_element_id++;
    +    if (new_state.sub_element_id == sub_rule.size()) {
    +      new_state.element_id++;
    +      new_state.sub_element_id = 0;
    +      Enqueue(new_state);
    +      // Assert: In a sequence, the bytestring can't be skipped. So the state can't be repeated.
    +    } else {
    +      tmp_states_to_be_added_.push_back(new_state);
    +    }
    +  }
    +  return;
    +}
    +
    +void EarleyParser::AdvanceCharacterClass(
    +    const ParserState& state, const uint8_t ch, const RuleExpr& sub_sequence
    +) {
    +  XGRAMMAR_DCHECK(sub_sequence.type == RuleExprType::kCharacterClass)
    +      << "The element type is not supported!";
    +
    +  // The state is matching a UTF8 character.
    +  if (state.sub_element_id > 0) {
    +    if ((ch & 0xC0) == 0x80) {
    +      auto new_state = state;
    +      new_state.sub_element_id--;
    +      // Check if the UTF8 character is completed.
    +      if (new_state.sub_element_id == 0) {
    +        new_state.element_id++;
    +        Enqueue(new_state);
    +        // Assert: In a sequence, the CharacterClass can't be skipped. So the state can't be
    +        // repeated. the fllowing tmp_process_state_queue_.push(new_state) is for the same reason.
    +      } else {
    +        tmp_states_to_be_added_.push_back(new_state);
    +      }
    +    }
    +    return;
    +  }
    +  bool is_negative = static_cast<bool>(sub_sequence[0]);
    +
    +  // This trick is based on the current structure that character class
    +  // can't accept a UTF8 character, unless it has a negation.
    +  if (!isascii(ch)) {
    +    if (!is_negative) {
    +      return;
    +    }
    +    auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(ch);
    +    if (!accepted) {
    +      return;
    +    }
    +
    +    // A new UTF8 character is accepted.
    +    XGRAMMAR_DCHECK(num_bytes > 1);
    +    auto new_state = state;
    +    new_state.sub_element_id = num_bytes - 1;
    +    tmp_states_to_be_added_.push_back(new_state);
    +    return;
    +  }
    +
    +  for (int i = 1; i < sub_sequence.size(); i += 2) {
    +    if (static_cast<uint8_t>(sub_sequence[i]) <= ch &&
    +        ch <= static_cast<uint8_t>(sub_sequence[i + 1])) {
    +      if (!is_negative) {
    +        auto new_state = state;
    +        new_state.element_id++;
    +        new_state.sub_element_id = 0;
    +        Enqueue(new_state);
    +      }
    +      return;
    +    }
    +  }
    +  if (is_negative) {
    +    auto new_state = state;
    +    new_state.element_id++;
    +    new_state.sub_element_id = 0;
    +    Enqueue(new_state);
    +  }
    +}
    +
    +void EarleyParser::AdvanceCharacterClassStar(
    +    const ParserState& state, const uint8_t ch, const RuleExpr& sub_sequence
    +) {
    +  XGRAMMAR_DCHECK(sub_sequence.type == RuleExprType::kCharacterClassStar)
    +      << "The element type is not supported!";
    +
    +  // The state is matching a UTF8 character.
    +  if (state.sub_element_id > 0) {
    +    if ((ch & 0xC0) == 0x80) {
    +      auto new_state = state;
    +      new_state.sub_element_id--;
    +      // Check if the UTF8 character is completed.
    +      if (new_state.sub_element_id == 0) {
    +        Enqueue(new_state);
    +      } else {
    +        tmp_states_to_be_added_.push_back(new_state);
    +      }
    +    }
    +    return;
    +  }
    +  bool is_negative = static_cast<bool>(sub_sequence[0]);
    +
    +  // This trick is based on the current structure that character class
    +  // can't accept a UTF8 character, unless it has a negation.
    +  if (!isascii(ch)) {
    +    if (!is_negative) {
    +      return;
    +    }
    +    auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(ch);
    +    if (!accepted) {
    +      return;
    +    }
    +    // A new UTF8 character is accepted.
    +    XGRAMMAR_DCHECK(num_bytes > 1);
    +    auto new_state = state;
    +    new_state.sub_element_id = num_bytes - 1;
    +    tmp_states_to_be_added_.push_back(new_state);
    +    return;
    +  }
    +
    +  for (int i = 1; i < sub_sequence.size(); i += 2) {
    +    if (static_cast<uint8_t>(sub_sequence[i]) <= ch &&
    +        ch <= static_cast<uint8_t>(sub_sequence[i + 1])) {
    +      if (!is_negative) {
    +        Enqueue(state);
    +      }
    +      return;
    +    }
    +  }
    +  if (is_negative) {
    +    Enqueue(state);
    +  }
    +}
    +
    +void EarleyParser::AdvanceTagDispatch(
    +    const ParserState& state, const uint8_t ch, const RuleExpr& cur_sequence
    +) {
    +  auto root_tag_dispatch_fsm_optional = grammar_->root_tag_dispatch_fsm;
    +  if (!root_tag_dispatch_fsm_optional) {
    +    XGRAMMAR_LOG(FATAL) << "The grammar does not have a root tag dispatch rule; it is not built.";
    +    XGRAMMAR_UNREACHABLE();
    +  }
    +  auto root_tag_dispatch_fsm = root_tag_dispatch_fsm_optional.value();
    +  auto start_node = root_tag_dispatch_fsm.GetStart();
    +  auto next_node = root_tag_dispatch_fsm->GetNextState(state.element_id, ch);
    +  auto new_state = state;
    +  if (next_node == CompactFSM::kNoNextState) {
    +    // Case 1. The new char cannot continue to be accepted by the tag dispatch fsm.
    +    // We try to accept the new char from the start node. If accepted, we go to the target
    +    // node. If it still cannot be accepted, we stay at the start node.
    +    auto new_next_node = root_tag_dispatch_fsm->GetNextState(start_node, ch);
    +    new_state.element_id = new_next_node == CompactFSM::kNoNextState ? start_node : new_next_node;
    +    if (root_tag_dispatch_fsm.IsEndState(new_state.element_id)) {
    +      tmp_process_state_queue_.push(new_state);
    +    } else {
    +      tmp_accept_stop_token_ = true;
    +      tmp_states_to_be_added_.push_back(new_state);
    +    }
    +  } else {
    +    // Case 2. The new char can continue to be accepted by the tag dispatch fsm.
    +    // We need to update the element id to the next node.
    +    new_state.element_id = next_node;
    +    if (root_tag_dispatch_fsm.IsEndState(next_node)) {
    +      tmp_process_state_queue_.push(new_state);
    +    } else {
    +      tmp_accept_stop_token_ = true;
    +      tmp_states_to_be_added_.push_back(new_state);
    +    }
    +  }
    +}
    +
    +bool RepeatDetector::IsVisited(const ParserState& state) const {
    +  // If the size is larger than the threshold, then we use the set to check.
    +  if (size_ > transition_threshold_) {
    +    return visited_set_.find(state) != visited_set_.end();
    +  }
    +  return std::find_if(
    +             visited_vector_.begin(),
    +             visited_vector_.begin() + size_,
    +             [&state](const ParserState& s) { return StateEqualForParsing()(state, s); }
    +         ) != visited_vector_.begin() + size_;
    +}
    +
    +void RepeatDetector::Insert(const ParserState& state) {
    +  if (size_ == transition_threshold_) {
    +    for (const auto& s : visited_vector_) {
    +      visited_set_.insert(s);
    +    }
    +  }
    +  size_++;
    +  if (size_ > transition_threshold_) {
    +    visited_set_.insert(state);
    +  } else {
    +    visited_vector_[size_ - 1] = state;
    +  }
    +}
    +
    +void RepeatDetector::Clear() {
    +  if (size_ > transition_threshold_) {
    +    visited_set_.clear();
    +  }
    +  size_ = 0;
    +}
    +
    +}  // namespace xgrammar
    
  • cpp/earley_parser.h +446 -0 added
    @@ -0,0 +1,446 @@
    +/*!
    + *  Copyright (c) 2025 by Contributors
    + * \file xgrammar/earley_parser.h
    + * \brief The header for the definition of the Earley parser.
    + */
    +
    +#ifndef XGRAMMAR_EARLEY_PARSER_H_
    +#define XGRAMMAR_EARLEY_PARSER_H_
    +#include <cstdint>
    +#include <map>
    +#include <ostream>
    +#include <queue>
    +#include <unordered_set>
    +#include <utility>
    +#include <vector>
    +
    +#include "grammar_data_structure.h"
    +#include "support/csr_array.h"
    +#include "support/utils.h"
    +#include "xgrammar/grammar.h"
    +
    +namespace xgrammar {
    +
    +/*!
    + * \brief The state of the Earley parser.
    + * In the implementation, a rule can only be a kchoices or a ktagdispatch.
    + * A kchoices rule must be composed of some ksequence rules, or a kemptyrule.
    + * In the ksequence, every element in the sequence must be a kbytestring, a
    + * kcharacterclass, a kcharacterclassstar, or a rule reference.
    + *
    + * -rule_id: The id of the rule.
    + * -sequence_id: The id of the sequence in the rule.
    + * -element_id: The id of the element in the sequence, or the id of the node in
    + * the tag dispatch fsm.
    + * -rule_start_pos: The id of the parent node in the Earley parser. i.e. the rule
    + * is predicted from the k-th character.
    + * -sub_element_id: The id of the sub element in the current element, i.e.:
    + *   - kbytestring: the id of the byte in the string.
    + *   - kcharacterclass: How many bytes are left to be read in the utf8 character.
    + *   - kcharacterclassstar: How many bytes are left to be read in the utf8 character.
    + */
    +struct ParserState {
    +  constexpr ParserState() = default;
    +
    +  constexpr ParserState(const ParserState&) = default;
    +
    +  ParserState& operator=(const ParserState&) = default;
    +
    +  constexpr ParserState(
    +      const int32_t& rule_id,
    +      const int32_t& sequence_id,
    +      const int32_t& element_id,
    +      const int32_t& rule_start_pos,
    +      const int32_t& sub_element_id
    +  )
    +      : rule_id(rule_id),
    +        sequence_id(sequence_id),
    +        element_id(element_id),
    +        rule_start_pos(rule_start_pos),
    +        sub_element_id(sub_element_id) {}
    +
    +  /*!
    +   * \brief A sequence_id value of kUnexpandedRuleStartSequenceId means a rule hasn't been
    +   * expanded.
    +   */
    +  static constexpr int32_t kUnexpandedRuleStartSequenceId = 128000;
    +
    +  /*!
    +   * \brief A rule_start_pos value of kNoPrevInputPos means this ParserState is the root of the
    +   * parsing stack.
    +   */
    +  static constexpr int32_t kNoPrevInputPos = -1;
    +
    +  /*! \brief A sequence_id value of kInvalidSequenceId means the ParserState is invalid. */
    +  static constexpr int32_t kInvalidSequenceId = -1;
    +
    +  /*! \brief The rule's id. */
    +  int32_t rule_id = -1;
    +
    +  /*! \brief Which choice in this rule is selected. */
    +  int32_t sequence_id = -1;
    +
    +  /*!
    +   * \brief Which element of the choice sequence is to be visited. When the current sequence is
    +   * a tag dispatch rule, this element id is the currently visited node.
    +   */
    +  int32_t element_id = -1;
    +
    +  /*! \brief The position of the state, i.e. from which position, the rule starts. */
    +  int32_t rule_start_pos = -1;
    +
    +  /*! \brief The id of the sub element in the current element of the sequence. */
    +  int32_t sub_element_id = 0;
    +
    +  /*! \brief The element is invalid when sequence_id is -1. */
    +  bool IsInvalid() const { return sequence_id == -1; }
    +
    +  static ParserState GetInvalidState() { return {-1, -1, -1, -1, -1}; }
    +
    +  bool operator==(const ParserState& other) const {
    +    return rule_id == other.rule_id && sequence_id == other.sequence_id &&
    +           element_id == other.element_id && sub_element_id == other.sub_element_id;
    +  }
    +
    +  bool operator<(const ParserState& other) const {
    +    if (rule_id != other.rule_id) return rule_id < other.rule_id;
    +    if (sequence_id != other.sequence_id) return sequence_id < other.sequence_id;
    +    if (element_id != other.element_id) return element_id < other.element_id;
    +    if (rule_start_pos != other.rule_start_pos) return rule_start_pos < other.rule_start_pos;
    +    return sub_element_id < other.sub_element_id;
    +  }
    +
    +  friend std::ostream& operator<<(std::ostream& os, const ParserState& state) {
    +    os << "ParserState(rule_id=" << state.rule_id << ", sequence_id=" << state.sequence_id
    +       << ", element_id=" << state.element_id << ", rule_start_pos=" << state.rule_start_pos
    +       << ", sub_element_id=" << state.sub_element_id << ")";
    +    return os;
    +  }
    +
    +  std::string ToString() const {
    +    return "ParserState(rule_id=" + std::to_string(rule_id) +
    +           ", sequence_id=" + std::to_string(sequence_id) +
    +           ", element_id=" + std::to_string(element_id) +
    +           ", rule_start_pos=" + std::to_string(rule_start_pos) +
    +           ", sub_element_id=" + std::to_string(sub_element_id) + ")";
    +  }
    +};
    +
    +XGRAMMAR_MEMBER_ARRAY(
    +    ParserState,
    +    &ParserState::rule_id,
    +    &ParserState::sequence_id,
    +    &ParserState::element_id,
    +    &ParserState::rule_start_pos,
    +    &ParserState::sub_element_id
    +);
    +
    +/*!
    + * \brief When getting the mask of the state, we don't need to consider the rule_start_pos.
    + */
    +class StateHashForCache {
    + public:
    +  size_t operator()(const ParserState& state) const {
    +    return HashCombine(state.rule_id, state.sequence_id, state.element_id, state.sub_element_id);
    +  }
    +};
    +
    +/*!
    + * \brief When matching the state, we need to consider the rule_start_pos, since if two states
    + * don't have the same rule_start_pos, they are not the same state.
    + */
    +class StateEqualForParsing {
    + public:
    +  bool operator()(const ParserState& lhs, const ParserState& rhs) const {
    +    return lhs.rule_id == rhs.rule_id && lhs.sequence_id == rhs.sequence_id &&
    +           lhs.element_id == rhs.element_id && lhs.rule_start_pos == rhs.rule_start_pos &&
    +           lhs.sub_element_id == rhs.sub_element_id;
    +  }
    +};
    +
    +/*!
    + * \brief This class is used to hash the ParserState for parsing.
    + * If two ParserStates don't have the same rule_start_pos, they are not the same state.
    + */
    +class StateHashForParsing {
    + public:
    +  size_t operator()(const ParserState& state) const {
    +    return HashCombine(
    +        state.rule_id,
    +        state.sequence_id,
    +        state.element_id,
    +        state.rule_start_pos,
    +        state.sub_element_id
    +    );
    +  }
    +};
    +
    +/*! \brief This class is used to detect the repeated states. */
    +class RepeatDetector {
    + private:
    +  const int transition_threshold_;
    +
    +  std::vector<ParserState> visited_vector_;
    +
    +  std::unordered_set<ParserState, StateHashForParsing, StateEqualForParsing> visited_set_;
    +
    +  int size_ = 0;
    +
    + public:
    +  RepeatDetector(const int transition_threshold = 50)
    +      : transition_threshold_(transition_threshold), size_(0) {
    +    visited_vector_.resize(transition_threshold_);
    +  }
    +
    +  /*!
    +   * \brief Check if the element is visited.
    +   * \return True if visited, false otherwise.
    +   */
    +  bool IsVisited(const ParserState& state) const;
    +
    +  /*!
    +   * \brief Add the state into the visited states.
    +   * \param state The state to be added.
    +   */
    +  void Insert(const ParserState& state);
    +
    +  /*! \brief Reset the detector. */
    +  void Clear();
    +};
    +
    +class EarleyParser {
    +  /*!
    +   * \brief Here is an article about Earley Parser.
    +   * https://en.wikipedia.org/wiki/Earley_parser#Pseudocode
    +   * We divide the parser states into three categories:
    +   * - Scanable (stored in scanable_state_history_).
    +   * - Predictable (if it predicts a new rule successfully, it is stored in
    +   * rule_id_to_completeable_states_).
    +   * - Completeable (can perform a completion operation).
    +   * A state will be stored in rule_id_to_completeable_states_ if it can be completed,
    +   * and it will be stored in scanable_state_history_ if it can be scanned. Otherwise,
    +   * it will be discarded.
    +   */
    + protected:
    +  using RuleExpr = Grammar::Impl::RuleExpr;
    +
    +  /*! \brief The grammar to be parsed. */
    +  Grammar grammar_;
    +
    +  /*! \brief In this round of advancing, check if the stop token can be accepted. */
    +  bool tmp_accept_stop_token_ = false;
    +
    +  /*! \brief is_completed_[i]: whether the stop token can be accepted after i characters. */
    +  std::vector<bool> is_completed_;
    +
    +  /*!
    +   * \brief rule_id_to_completeable_states_[i] holds the states at input position i, keyed by
    +   * rule id j. The Earley parser needs them for completion.
    +   */
    +  std::vector<std::multimap<int32_t, ParserState>> rule_id_to_completeable_states_;
    +
    +  /*!
    +   * \brief The state history. scanable_state_history_[i] stores the states after accepting
    +   * input[i-1].
    +   */
    +  CSRArray<ParserState> scanable_state_history_;
    +
    +  /*!
    +   * \brief A temporary vector only used in Advance, used to add states to
    +   * scanable_state_history_.
    +   */
    +  std::vector<ParserState> tmp_states_to_be_added_;
    +
    +  /*! \brief It's the processing queue of the earley parser. */
    +  std::queue<ParserState> tmp_process_state_queue_;
    +
    +  /*! \brief The class is used to check if a state has been added into the queue. */
    +  RepeatDetector tmp_states_visited_in_queue_;
    +
    +  /*! \brief Check if the stop token is accepted. */
    +  bool stop_token_is_accepted_ = false;
    +
    +  /*!
    +   * \brief Check if the state has been added into the queue.
    +   * \param state The state to check.
    +   * \return True if in the vector, false otherwise.
    +   */
    +  bool IsStateVisitedInQueue(const ParserState& state) const {
    +    return tmp_states_visited_in_queue_.IsVisited(state);
    +  }
    +
    +  /*!
    +   * \brief The scanning operation of the Earley parser. Put the new states in the queue.
    +   */
    +  void Scan(const ParserState& state, const uint8_t ch);
    +
    +  /*!
    +   * \brief The completion operation of the Earley parser.
    +   * \details If a state can't be scanned, adding it to the next
    +   * states is useless. However, the end of the grammar is used
    +   * to check whether the grammar is completed, so it should be
    +   * added to the next states.
    +   */
    +  void Complete(const ParserState& state, const RuleExpr& rule_expr);
    +
    +  /*!
    +   * \brief The prediction operation of the Earley parser.
    +   * \return First: If the state is scanable, or the state is the end of the grammar,
    +   * then return true, otherwise return false.
    +   * \return Second: If the state is completable, then return true, otherwise return false.
    +   */
    +  std::pair<bool, bool> Predict(const ParserState& state, const RuleExpr& rule_expr);
    +
    +  /*!
    +   * \brief Handle the unexpanded rule, used for pushing initial state.
    +   * \param state The state to be handled.
    +   * \return True if the rule is unexpanded, false otherwise.
    +   */
    +  bool ExpandAndEnqueueUnexpandedState(const ParserState& state);
    +
    +  /*!
    +   * \brief Expand the rule, used for RuleRef and kTagDispatch.
    +   * \param state The state to be expanded, which is the parent state.
    +   * The type of the state is kTagDispatch or kSequence. Moreover, the
    +   * element of the sequence should be a rule reference; the node in
    +   * the kTagDispatch should be an end node.
    +   * \param rule_expr The rule expression to be expanded.
    +   * \param sub_rule_expr The sub rule expression to be expanded, especially
    +   * when the rule is a kSequence, and the sub rule is a kRuleRef.
    +   */
    +  void ExpandNextRuleRefElement(
    +      const ParserState& state, const RuleExpr& rule_expr, const RuleExpr* sub_rule_expr
    +  );
    +
    +  /*!
    +   * \brief Advance the parser to the next state when the sub sequence is a kCharacterClass.
    +   * \param state The state to be advanced.
    +   * \param ch The character to be advanced.
    +   * \param sub_sequence The sub sequence to be checked.
    +   * \return The next state, Invalid state if the character is not accepted.
    +   */
    +  void AdvanceCharacterClass(
    +      const ParserState& state, const uint8_t ch, const RuleExpr& sub_sequence
    +  );
    +
    +  /*!
    +   * \brief Advance the parser to the next state when the sub sequence is a kByteString.
    +   * \param state The state to be advanced.
    +   * \param ch The character to be advanced.
    +   * \param sub_sequence The sub sequence to be checked.
    +   * \return The next state, Invalid state if the character is not accepted.
    +   */
    +  void AdvanceByteString(const ParserState& state, const uint8_t ch, const RuleExpr& sub_sequence);
    +
    +  /*!
    +   * \brief Advance the parser to the next state when the sub sequence is a kCharacterClassStar.
    +   * \param state The state to be advanced.
    +   * \param ch The character to be advanced.
    +   * \param sub_sequence The sub sequence to be checked.
    +   * \return The next state, Invalid state if the character is not accepted.
    +   */
    +  void AdvanceCharacterClassStar(
    +      const ParserState& state, const uint8_t ch, const RuleExpr& sub_sequence
    +  );
    +
    +  /*!
    +   * \brief Advance the parser to the next state when the sequence is a kTagDispatch.
    +   * \param state The state to be advanced.
    +   * \param ch The character to be advanced.
    +   * \param cur_sequence The sequence of the current state.
    +   * \return The next state, Invalid state if the character is not accepted.
    +   */
    +  void AdvanceTagDispatch(const ParserState& state, const uint8_t ch, const RuleExpr& cur_sequence);
    +
    +  /*!
    +   * \brief Enqueue the state into the queue.
    +   * \param state The state to be enqueued.
    +   * \details The state is enqueued if it is not visited in the queue.
    +   */
    +  void Enqueue(const ParserState& state) {
    +    if (!IsStateVisitedInQueue(state)) {
    +      tmp_process_state_queue_.push(state);
    +      tmp_states_visited_in_queue_.Insert(state);
    +    }
    +  }
    +
    + public:
    +  /*!
    +   * \brief Constructor of the Earley parser.
    +   * \param grammar The grammar to be parsed.
    +   * \param initial_state The initial state to be pushed into the parser.
    +   */
    +  EarleyParser(
    +      const Grammar& grammar, const ParserState& initial_state, const bool need_expand = true
    +  );
    +
    +  /*!
    +   * \brief From the current states, advance to the next state.
    +   * \param ch The character to be advanced.
    +   * \return True if the character is accepted, false otherwise.
    +   * \note If the character isn't accepted, then the states won't be changed.
    +   */
    +  bool Advance(const uint8_t ch);
    +
    +  /*!
    +   * \brief Remove the newly added states.
    +   * \param count The number of states to be removed.
    +   */
    +  void PopLastStates(int32_t count = 1);
    +
    +  /*!
    +   * \brief Check whether any of the multiple states stored in the parser has already completed.
    +   * \note Since the parser contains multiple parallel states, some may have already completed,
    +   * while others might still be able to accept more characters.
    +   * \return True if the root rule is completed, false otherwise.
    +   */
    +  bool IsCompleted() const;
    +
    +  /*!
    +   * \brief Push the initial state into the Earley parser.
    +   * \param state The initial state to be pushed.
    +   */
    +  void PushStateAndExpand(const ParserState& state);
    +
    +  /*!
    +   * \brief Reset the parser.
    +   * \note This function is used to reset the parser, and initialize the
    +   * parser with the root rule.
    +   */
    +  void Reset();
    +
    +  /*!
    +   * \brief Get the current scanable states.
    +   * \return The scanable states.
    +   */
    +  std::vector<ParserState> GetLatestScanableStates() const {
    +    std::vector<ParserState> latest_states;
    +    for (const auto& state : scanable_state_history_[scanable_state_history_.size() - 1]) {
    +      latest_states.push_back(state);
    +    }
    +    return latest_states;
    +  }
    +
    +  /*!
    +   * \brief Push one state to check if it can accept the token.
    +   * \param state The state to be pushed.
    +   */
    +  void PushOneStateToCheck(const ParserState& state) {
    +    rule_id_to_completeable_states_.emplace_back();
    +    is_completed_.push_back(is_completed_.back());
    +    scanable_state_history_.PushBack(&state, 1);
    +    return;
    +  }
    +
    +  std::string PrintStates() const {
    +    std::string result;
    +    result += "There are " + std::to_string(scanable_state_history_.size()) + " scanable states:\n";
    +    for (const auto& state : scanable_state_history_[scanable_state_history_.size() - 1]) {
    +      result += state.ToString() + "\n";
    +    }
    +    return result;
    +  }
    +};
    +
    +}  // namespace xgrammar
    +
    +#endif  // XGRAMMAR_EARLEY_PARSER_H_
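
The header above pairs a processing queue with a `RepeatDetector`, and `Enqueue` only pushes a state that has not been seen yet. The following is a minimal sketch of why that deduplication matters for recursive grammars; it is not the library's implementation. `ToyState`, `ToyGrammar`, `PredictAll`, and `MakeLeftRecursiveGrammar` are hypothetical names invented for illustration, and only the prediction step at a single input position is modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <queue>
#include <unordered_set>
#include <vector>

// Hypothetical, simplified stand-in for the patch's ParserState.
struct ToyState {
  int32_t rule_id;
  int32_t sequence_id;
  int32_t element_id;
  int32_t rule_start_pos;
};

// Equality and hash both include rule_start_pos, in the spirit of
// StateEqualForParsing / StateHashForParsing in the header above.
struct ToyStateEqual {
  bool operator()(const ToyState& a, const ToyState& b) const {
    return a.rule_id == b.rule_id && a.sequence_id == b.sequence_id &&
           a.element_id == b.element_id && a.rule_start_pos == b.rule_start_pos;
  }
};

struct ToyStateHash {
  std::size_t operator()(const ToyState& s) const {
    std::size_t h = 0;
    for (int32_t v : {s.rule_id, s.sequence_id, s.element_id, s.rule_start_pos}) {
      h ^= std::hash<int32_t>()(v) + 0x9e3779b9 + (h << 6) + (h >> 2);
    }
    return h;
  }
};

// An element >= 0 refers to a rule; kTerminal marks a terminal symbol.
constexpr int32_t kTerminal = -1;
// grammar[rule][sequence][element]
using ToyGrammar = std::vector<std::vector<std::vector<int32_t>>>;

// Runs only the prediction step at one input position and returns the number
// of distinct states processed. The visited set is what guarantees termination.
std::size_t PredictAll(const ToyGrammar& grammar, int32_t root_rule, int32_t pos) {
  std::queue<ToyState> queue;
  std::unordered_set<ToyState, ToyStateHash, ToyStateEqual> visited;
  auto enqueue = [&](const ToyState& s) {
    if (visited.insert(s).second) queue.push(s);  // skip states already seen
  };
  auto expand_rule = [&](int32_t rule) {
    for (int32_t seq = 0; seq < static_cast<int32_t>(grammar[rule].size()); ++seq) {
      enqueue(ToyState{rule, seq, 0, pos});
    }
  };
  expand_rule(root_rule);
  std::size_t processed = 0;
  while (!queue.empty()) {
    ToyState s = queue.front();
    queue.pop();
    ++processed;
    const auto& sequence = grammar[s.rule_id][s.sequence_id];
    if (s.element_id >= static_cast<int32_t>(sequence.size())) continue;
    int32_t element = sequence[s.element_id];
    if (element != kTerminal) expand_rule(element);  // predict the referenced rule
  }
  return processed;
}

// A -> A 'a' | 'b'  (rule 0 refers to itself in first position)
ToyGrammar MakeLeftRecursiveGrammar() {
  ToyGrammar grammar(1);
  grammar[0].push_back({0, kTerminal});
  grammar[0].push_back({kTerminal});
  return grammar;
}
```

With the left-recursive toy grammar, predicting rule 0 re-proposes the same two states, and the visited set drops them, so the loop ends after processing two states. A recursive-descent expansion of the same rule without such a check would call itself indefinitely.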
    
  • cpp/fsm.cc+1 1 modified
    @@ -285,7 +285,7 @@ CompactFSM FSM::Impl::ToCompact() {
       CSRArray<FSMEdge> edges;
       for (int i = 0; i < static_cast<int>(edges_.size()); ++i) {
         std::sort(edges_[i].begin(), edges_[i].end());
    -    edges.Insert(edges_[i]);
    +    edges.PushBack(edges_[i]);
       }
       return CompactFSM(edges);
     }
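
The `CSRArray` type that this hunk appends to is also used for the parser's state history in `earley_parser.h`, but its definition is not part of this excerpt. As a rough illustration of the general compressed-sparse-row idea (many variable-length rows packed into one flat buffer plus an offsets array), here is a minimal sketch. `ToyCSRArray` and its methods are hypothetical and are not the library's interface.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CSR-style container: all rows share one flat buffer, and
// offsets_[i] .. offsets_[i + 1] delimits row i.
template <typename T>
class ToyCSRArray {
 public:
  ToyCSRArray() : offsets_{0} {}

  // Append one row; its elements are copied to the end of the flat buffer.
  void PushBack(const std::vector<T>& row) {
    data_.insert(data_.end(), row.begin(), row.end());
    offsets_.push_back(data_.size());
  }

  // Remove the last `count` rows. The caller must ensure count <= size().
  void PopBack(std::size_t count) {
    offsets_.resize(offsets_.size() - count);
    data_.erase(data_.begin() + static_cast<std::ptrdiff_t>(offsets_.back()), data_.end());
  }

  // Number of rows currently stored.
  std::size_t size() const { return offsets_.size() - 1; }

  // Copy of row i, returned as a standalone vector for simplicity.
  std::vector<T> Row(std::size_t i) const {
    return std::vector<T>(
        data_.begin() + static_cast<std::ptrdiff_t>(offsets_[i]),
        data_.begin() + static_cast<std::ptrdiff_t>(offsets_[i + 1])
    );
  }

 private:
  std::vector<T> data_;
  std::vector<std::size_t> offsets_;
};
```

Appending and removing whole rows at the end touches only the tail of both vectors, which suits a history that grows by one row per accepted character and is rolled back from the end.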
    
  • cpp/grammar_compiler.cc+395 106 modified
    @@ -5,15 +5,21 @@
     
     #include <xgrammar/compiler.h>
     
    +#include <algorithm>
    +#include <bitset>
    +#include <cctype>
     #include <cstddef>
    +#include <cstdint>
    +#include <utility>
     #include <variant>
    +#include <vector>
     
     #include "compiled_grammar_data_structure.h"
    +#include "earley_parser.h"
     #include "fsm.h"
     #include "fsm_builder.h"
     #include "grammar_data_structure.h"
     #include "grammar_functor.h"
    -#include "grammar_matcher_base.h"
     #include "support/logging.h"
     #include "support/thread_pool.h"
     #include "support/thread_safe_cache.h"
    @@ -98,6 +104,28 @@ AdaptiveTokenMask::AdaptiveTokenMask(
       this->uncertain_indices = uncertain_indices;
     }
     
    +AdaptiveTokenMask::AdaptiveTokenMask(
    +    size_t vocab_size,
    +    const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
    +    const std::vector<int32_t>& accepted_indices,
    +    const std::vector<int32_t>& uncertain_indices
    +) {
    +  auto size_acc = accepted_indices.size();
    +
    +  store_type = size_acc >= USE_BITSET_THRESHOLD ? StoreType::kAcceptedBitset : StoreType::kAccepted;
    +
    +  if (store_type == StoreType::kAcceptedBitset) {
    +    accepted_bitset = DynamicBitset(vocab_size);
    +    for (auto idx : accepted_indices) {
    +      accepted_bitset.Set(sorted_decoded_vocab[idx].first, true);
    +    }
    +  } else {
    +    XGRAMMAR_DCHECK(store_type == StoreType::kAccepted);
    +    this->accepted_indices = accepted_indices;
    +  }
    +  this->uncertain_indices = uncertain_indices;
    +}
    +
     std::string AdaptiveTokenMask::Print(const TokenizerInfo& tokenizer_info) const {
       constexpr int kMaxPrintTokens = 100;
       std::stringstream ss;
    @@ -201,22 +229,39 @@ TokenizerInfo CompiledGrammar::GetTokenizerInfo() const { return pimpl_->GetToke
     /************** Use GrammarMatcher to generate the AdaptiveTokenMaskCache **************/
     
     /*! \brief The concrete implementation of GrammarMatcherNode. */
    -class GrammarMatcherForTokenMaskCache : public GrammarMatcherBase {
    +class GrammarMatcherForTokenMaskCache : public EarleyParser {
      public:
    -  // Do not expand the initial stack element: we want to find the accepted/rejected tokens
    -  // that exactly start from the initial stack element.
    -  GrammarMatcherForTokenMaskCache(const Grammar& grammar, StackElement init_stack_element)
    -      : GrammarMatcherBase(grammar, init_stack_element, false),
    -        init_rule_id(init_stack_element.rule_id) {}
    -
    +  GrammarMatcherForTokenMaskCache(
    +      const Grammar& grammar, const ParserState& init_state, const bool& need_expand = true
    +  )
    +      : EarleyParser(grammar, init_state),
    +        init_rule_id(init_state.rule_id),
    +        initial_state(init_state) {}
       /*!
    -   * \brief Get the adaptive token mask for the given StackElement.
    +   * \brief Get the adaptive token mask for the given ParserState.
        * \param is_root_rule Whether to consider the parent rule. If false, there will be
        * no uncertain tokens. Useful for the root rule.
        */
       AdaptiveTokenMask GetAdaptiveTokenMask(
           size_t vocab_size,
           const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
    +      const std::vector<int32_t>& subtree_nodes_range,
    +      bool is_root_rule
    +  );
    +
    +  /*!
    +   * \brief Get the token mask for the given ParserState.
    +   * \param sorted_decoded_vocab The sorted decoded vocabulary.
    +   * \param first_char_mask The first character mask.
    +   * \param is_root_rule Whether to consider the parent rule. If false, there will be
    +   * no uncertain tokens. Useful for the root rule.
    +   * \returns True if the rejected indices are filled as usual, False otherwise.
    +   * It's used to determine which construction function will be used.
    +   */
    +  bool GetTokenMaskWithFirstCharacterCheck(
    +      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
    +      const std::bitset<256>& first_char_mask,
    +      const std::vector<int>& subtree_nodes_range,
           bool is_root_rule
       );
     
    @@ -226,9 +271,18 @@ class GrammarMatcherForTokenMaskCache : public GrammarMatcherBase {
           const std::string& token, const std::vector<bool>& can_reach_end_stack
       );
     
    +  /*! \brief Check if speculative calculation will be applied.*/
    +  bool IsSpeculativeCalculationApplied(
    +      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
    +      int possible_token_num
    +  );
    +
       // The id of the initial rule.
       int32_t init_rule_id;
     
    +  // The initial state of the parser.
    +  ParserState initial_state;
    +
       // Temporary data for GetAdaptiveTokenMask.
       std::vector<int32_t> tmp_accepted_indices_;
       std::vector<int32_t> tmp_rejected_indices_;
    @@ -244,8 +298,9 @@ bool GrammarMatcherForTokenMaskCache::IsTokenPassLookaheadAssertion(
       if (lookahead_assertion_id == -1) {
         return true;
       }
    -  auto lookahead_stack_element = StackElement(-1, lookahead_assertion_id, 0);
    -  PushInitialState(lookahead_stack_element, true);
    +  auto lookahead_state =
    +      ParserState(/*rule_id*/ -1, lookahead_assertion_id, 0, ParserState::kNoPrevInputPos, 0);
    +  PushStateAndExpand(lookahead_state);
       int token_len = token.size();
     
       // Find all positions that can come to and end. Then check if the suffix from that position
    @@ -256,116 +311,346 @@ bool GrammarMatcherForTokenMaskCache::IsTokenPassLookaheadAssertion(
         }
         int last_accept_pos = i - 1;
         for (int pos = i; pos < token_len; ++pos) {
    -      if (!AcceptChar(token[pos])) {
    +      if (!Advance(token[pos])) {
             break;
           }
           last_accept_pos = pos;
           // Case 1. The whole rule is finished.
    -      if (CanReachEnd()) {
    +      if (IsCompleted()) {
             // accepted chars: pos - i + 1
             // we need to rollback the pushed initial state as well
    -        RollbackChars(pos - i + 2);
    +        PopLastStates(pos - i + 2);
             return true;
           }
         }
         // Case 2. The whole token is accepted
         if (last_accept_pos == token_len - 1) {
    -      RollbackChars(last_accept_pos - i + 2);
    +      PopLastStates(last_accept_pos - i + 2);
           return true;
         }
         // Case 3. The token is not accepted. Check the next position.
    -    RollbackChars(last_accept_pos - i + 1);
    +    PopLastStates(last_accept_pos - i + 1);
       }
     
    -  RollbackChars(1);
    +  PopLastStates(1);
       return false;
     }
     
    -AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
    -    size_t vocab_size,
    +// Comparator for std::pair<int32_t, std::string> based on the string value.
    +class IntStringPairComparator {
    + public:
    +  bool operator()(
    +      const std::pair<int32_t, std::string>& lhs, const std::pair<int32_t, std::string>& rhs
    +  ) const {
    +    return lhs.second < rhs.second;
    +  }
    +};
    +
    +int GetPossibleTokenIntervals(
    +    const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
    +    const std::bitset<256>& first_char_mask,
    +    std::vector<std::pair<int32_t, int32_t>>& possible_intervals
    +) {
    +  int possible_token_num = 0;
    +  int matched_size = 0;
    +  int last_interval_end = -1;
    +  for (int32_t i = 0; i < 256; i++) {
    +    if (first_char_mask[i]) {
    +      if (last_interval_end == -1) {
    +        last_interval_end = i;
    +      }
    +    } else {
    +      if (last_interval_end != -1) {
    +        int32_t interval_left_end =
    +            std::lower_bound(
    +                sorted_decoded_vocab.begin() + matched_size,
    +                sorted_decoded_vocab.end(),
    +                std::make_pair(0, std::string(1, static_cast<uint8_t>(last_interval_end))),
    +                IntStringPairComparator()
    +            ) -
    +            sorted_decoded_vocab.begin();
    +        int32_t interval_right_end = std::lower_bound(
    +                                         sorted_decoded_vocab.begin() + interval_left_end,
    +                                         sorted_decoded_vocab.end(),
    +                                         std::make_pair(0, std::string(1, static_cast<uint8_t>(i))),
    +                                         IntStringPairComparator()
    +                                     ) -
    +                                     sorted_decoded_vocab.begin();
    +        possible_intervals.emplace_back(interval_left_end, interval_right_end);
    +        possible_token_num += interval_right_end - interval_left_end;
    +        last_interval_end = -1;
    +        matched_size = interval_right_end;
    +      }
    +    }
    +  }
    +
    +  if (last_interval_end != -1) {
    +    // If the last interval is not closed, we need to close it.
    +    int32_t interval_left_end =
    +        std::lower_bound(
    +            sorted_decoded_vocab.begin() + matched_size,
    +            sorted_decoded_vocab.end(),
    +            std::make_pair(0, std::string(1, static_cast<uint8_t>(last_interval_end))),
    +            IntStringPairComparator()
    +        ) -
    +        sorted_decoded_vocab.begin();
    +    possible_intervals.emplace_back(interval_left_end, sorted_decoded_vocab.size());
    +    possible_token_num += sorted_decoded_vocab.size() - interval_left_end;
    +  }
    +  return possible_token_num;
    +}
    +
    +bool GrammarMatcherForTokenMaskCache::IsSpeculativeCalculationApplied(
    +    const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab, int possible_token_num
    +) {
    +  using RuleExprType = Grammar::Impl::RuleExprType;
    +  // Check if the initial state is self-recursive-like. If the state is self-recursive-like,
    +  // and it covers a large part of the vocabulary, we will do speculative calculation in compiling.
    +  if (initial_state.sub_element_id == 0 &&
    +      possible_token_num > static_cast<int>(sorted_decoded_vocab.size() / 4)) {
    +    const auto& sequence_expr = grammar_->GetRuleExpr(initial_state.sequence_id);
    +    // A self-recursive-like rule must be a sequence.
    +    if (sequence_expr.type == RuleExprType::kSequence) {
    +      const auto& current_element_expr =
    +          grammar_->GetRuleExpr(sequence_expr[initial_state.element_id]);
    +      // If the current element is a character class star, then it's self-recursive without doubt.
    +      if (current_element_expr.type == RuleExprType::kCharacterClassStar) {
    +        return true;
    +        // If the current element is a character class, and the next element is a rule ref to
    +        // itself, and the rule only has 2 elements, then it's self-recursive-like.
    +      } else if (current_element_expr.type == RuleExprType::kCharacterClass &&
    +                 sequence_expr.size() == 2 && initial_state.element_id == 0) {
    +        const auto& end_element_expr = grammar_->GetRuleExpr(sequence_expr[1]);
    +        if (end_element_expr.type == RuleExprType::kRuleRef &&
    +            end_element_expr[0] == initial_state.rule_id) {
    +          return true;
    +        }
    +      }
    +    }
    +  }
    +  return false;
    +}
    +
    +bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
         const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
    +    const std::bitset<256>& first_char_mask,
    +    const std::vector<int>& subtree_nodes_range,
         bool is_root_rule
     ) {
    -  tmp_accepted_indices_.clear();
    -  tmp_rejected_indices_.clear();
    -  tmp_uncertain_indices_.clear();
    +  // The pair (a, b) means [a, b). Initialize the possible intervals.
    +  std::vector<std::pair<int32_t, int32_t>> possible_intervals;
    +  int possible_token_num =
    +      GetPossibleTokenIntervals(sorted_decoded_vocab, first_char_mask, possible_intervals);
     
    -  // For every character in the current token, stores whether it is possible to reach the end of
    -  // the rule when matching until this character. Store it in a stack for later rollback.
    -  tmp_can_reach_end_stack_.assign({CanReachEnd()});
    -  tmp_can_reach_end_prefix_or_stack_.assign({tmp_can_reach_end_stack_.back()});
    +  // Check if the type of the mask can be kRejected.
    +  bool fill_reject_indices =
    +      (sorted_decoded_vocab.size() - possible_token_num) < AdaptiveTokenMask::USE_BITSET_THRESHOLD;
    +
    +  XGRAMMAR_DCHECK(possible_intervals.size() > 0)
    +      << "There should be at least one possible interval for the first character mask.";
    +
    +  if (possible_intervals[0].first != 0 && fill_reject_indices) {
    +    for (int i = 0; i < possible_intervals[0].first; ++i) {
    +      tmp_rejected_indices_.push_back(i);
    +    }
    +  }
    +
    +  bool speculative_calculation =
    +      IsSpeculativeCalculationApplied(sorted_decoded_vocab, possible_token_num);
     
       int prev_matched_size = 0;
    -  for (int i = 0; i < static_cast<int>(sorted_decoded_vocab.size()); ++i) {
    -    const auto& token = sorted_decoded_vocab[i].second;
    -
    -    bool accepted = true;
    -
    -    // Many tokens may contain the same prefix, so we will avoid unnecessary matching
    -    // by finding the longest common prefix with the previous token.
    -    if (i > 0) {
    -      const auto& prev_token = sorted_decoded_vocab[i - 1].second;
    -      int lcp_len =
    -          std::mismatch(token.begin(), token.end(), prev_token.begin(), prev_token.end()).first -
    -          token.begin();
    -      if (lcp_len > prev_matched_size) {
    -        // Case 1. The common prefix is rejected by the matcher in the last token. Reject
    -        // directly.
    -        accepted = false;
    -      } else if (lcp_len < prev_matched_size) {
    -        // Case 2. The common prefix is shorter than the previous matched size. Rollback
    -        // the non-common part.
    -        RollbackChars(prev_matched_size - lcp_len);
    -        tmp_can_reach_end_stack_.erase(
    -            tmp_can_reach_end_stack_.end() - (prev_matched_size - lcp_len),
    -            tmp_can_reach_end_stack_.end()
    -        );
    -        tmp_can_reach_end_prefix_or_stack_.erase(
    -            tmp_can_reach_end_prefix_or_stack_.end() - (prev_matched_size - lcp_len),
    -            tmp_can_reach_end_prefix_or_stack_.end()
    -        );
    +  int last_rejected_range = 0;
    +  const std::string* prev_token = nullptr;
    +  for (size_t interval_idx = 0; interval_idx < possible_intervals.size(); ++interval_idx) {
    +    const auto& interval = possible_intervals[interval_idx];
    +    for (int i = interval.first; i < interval.second; ++i) {
    +      // Check if the current token is in the rejected range. i.e. check if the current token
    +      // is on the subtree of the rejected token.
    +      if (i < last_rejected_range) {
    +        if (fill_reject_indices) {
    +          tmp_rejected_indices_.push_back(i);
    +          fill_reject_indices =
    +              tmp_rejected_indices_.size() < AdaptiveTokenMask::USE_BITSET_THRESHOLD;
    +        } else {
    +          i = last_rejected_range - 1;
    +        }
    +        continue;
           }
    -      prev_matched_size = std::min(prev_matched_size, lcp_len);
    -    }
     
    -    if (accepted) {
    -      // Accept the rest chars one by one
    -      for (int j = prev_matched_size; j < static_cast<int>(token.size()); ++j) {
    -        if (!AcceptChar(token[j], false)) {
    +      const auto& token = sorted_decoded_vocab[i].second;
    +      // This optimization is useful for simple self-recursive rules, like string content.
    +      if (speculative_calculation) {
    +        bool all_accepted = true;
    +        for (char ch : token) {
    +          // If the first character is not the ascii character or can't be accepted by the
    +          // first character mask, we need to check them in the parser.
    +          if (isascii(ch) == 0 || !first_char_mask[static_cast<uint8_t>(ch)]) {
    +            all_accepted = false;
    +            break;
    +          }
    +        }
    +        if (all_accepted) {
    +          tmp_accepted_indices_.push_back(i);
    +          continue;
    +        }
    +      }
    +      // Many tokens may contain the same prefix, so we will avoid unnecessary matching
    +      // by finding the longest common prefix with the previous token.
    +      bool accepted = true;
    +      if (prev_token != nullptr) {
    +        int lcp_len =
    +            std::mismatch(token.begin(), token.end(), prev_token->begin(), prev_token->end())
    +                .first -
    +            token.begin();
    +        if (lcp_len > prev_matched_size) {
    +          // Case 1. The common prefix is rejected by the matcher in the last token. Reject
    +          // directly.
               accepted = false;
    -          break;
    +        } else if (lcp_len < prev_matched_size) {
    +          // Case 2. The common prefix is shorter than the previous matched size. Rollback
    +          // the non-common part.
    +          PopLastStates(prev_matched_size - lcp_len);
    +          tmp_can_reach_end_stack_.erase(
    +              tmp_can_reach_end_stack_.end() - (prev_matched_size - lcp_len),
    +              tmp_can_reach_end_stack_.end()
    +          );
    +          tmp_can_reach_end_prefix_or_stack_.erase(
    +              tmp_can_reach_end_prefix_or_stack_.end() - (prev_matched_size - lcp_len),
    +              tmp_can_reach_end_prefix_or_stack_.end()
    +          );
    +        }
    +        prev_matched_size = std::min(prev_matched_size, lcp_len);
    +      }
    +
    +      prev_token = &token;
    +
    +      if (accepted) {
    +        // Accept the rest chars one by one.
    +        for (int j = prev_matched_size; j < static_cast<int>(token.size()); ++j) {
    +          if (!Advance(token[j])) {
    +            accepted = false;
    +            break;
    +          }
    +          tmp_can_reach_end_stack_.push_back(IsCompleted());
    +          tmp_can_reach_end_prefix_or_stack_.push_back(
    +              tmp_can_reach_end_stack_.back() || tmp_can_reach_end_prefix_or_stack_.back()
    +          );
    +          prev_matched_size = j + 1;
             }
    -        tmp_can_reach_end_stack_.push_back(CanReachEnd());
    -        tmp_can_reach_end_prefix_or_stack_.push_back(
    -            tmp_can_reach_end_stack_.back() || tmp_can_reach_end_prefix_or_stack_.back()
    -        );
    -        prev_matched_size = j + 1;
    +      }
    +
    +      bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back();
    +
    +      if (accepted) {
    +        tmp_accepted_indices_.push_back(i);
    +      } else if (can_reach_end && !is_root_rule &&
    +                 IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_) &&
    +                 prev_matched_size > 0) {
    +        // 1. If the current rule is the root rule (is_root_rule=true), there are no
    +        // uncertain tokens. Not accepted tokens are just rejected.
    +        // 2. If a token cannot pass the lookahead assertion, it is rejected.
    +        tmp_uncertain_indices_.push_back(i);
    +      } else {
    +        tmp_rejected_indices_.push_back(i);
    +        last_rejected_range = subtree_nodes_range[i];
    +        fill_reject_indices =
    +            tmp_rejected_indices_.size() < AdaptiveTokenMask::USE_BITSET_THRESHOLD;
           }
         }
    +    if (interval_idx != possible_intervals.size() - 1 && fill_reject_indices) {
    +      const auto& next_interval = possible_intervals[interval_idx + 1];
    +      for (int i = interval.second; i < next_interval.first; ++i) {
    +        tmp_rejected_indices_.push_back(i);
    +      }
    +      fill_reject_indices = tmp_rejected_indices_.size() < AdaptiveTokenMask::USE_BITSET_THRESHOLD;
    +    }
    +  }
     
    -    bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back();
    +  // Rollback the last matched part.
    +  PopLastStates(prev_matched_size);
     
    -    if (accepted) {
    -      tmp_accepted_indices_.push_back(i);
    -    } else if (can_reach_end && !is_root_rule &&
    -               IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_)) {
    -      // 1. If the current rule is the root rule (is_root_rule=true), there are no
    -      // uncertain tokens. Not accepted tokens are just rejected.
    -      // 2. If a token cannot pass the lookahead assertion, it is rejected.
    -      tmp_uncertain_indices_.push_back(i);
    -    } else {
    +  if (possible_intervals.back().second != static_cast<int>(sorted_decoded_vocab.size()) &&
    +      fill_reject_indices) {
    +    // If the last interval is not closed, we need to reject the rest tokens.
    +    for (int i = possible_intervals.back().second;
    +         i < static_cast<int>(sorted_decoded_vocab.size());
    +         ++i) {
           tmp_rejected_indices_.push_back(i);
         }
       }
    -  // Rollback the last matched part
    -  RollbackChars(prev_matched_size);
    -  return AdaptiveTokenMask(
    -      vocab_size,
    -      sorted_decoded_vocab,
    -      tmp_accepted_indices_,
    -      tmp_rejected_indices_,
    -      tmp_uncertain_indices_
    +
    +  return fill_reject_indices;
    +}
    +
    +AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
    +    size_t vocab_size,
    +    const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
    +    const std::vector<int32_t>& subtree_nodes_range,
    +    bool is_root_rule
    +) {
    +  tmp_accepted_indices_.clear();
    +  tmp_rejected_indices_.clear();
    +  tmp_uncertain_indices_.clear();
    +  // For every character in the current token, stores whether it is possible to reach the end of
    +  // the rule when matching until this character. Store it in a stack for later rollback.
    +  tmp_can_reach_end_stack_.assign({IsCompleted()});
    +  tmp_can_reach_end_prefix_or_stack_.assign({tmp_can_reach_end_stack_.back()});
    +  std::bitset<256> first_character_mask;
    +  const auto& sequence = grammar_->GetRuleExpr(initial_state.sequence_id);
    +  if (sequence.type == Grammar::Impl::RuleExprType::kSequence) {
    +    const auto& sub_sequence = grammar_->GetRuleExpr(sequence[initial_state.element_id]);
    +    switch (sub_sequence.type) {
    +      case Grammar::Impl::RuleExprType::kByteString: {
    +        first_character_mask[sub_sequence[initial_state.sub_element_id]] = true;
    +        break;
    +      }
    +      case xgrammar::Grammar::Impl::RuleExprType::kCharacterClass:
    +      case xgrammar::Grammar::Impl::RuleExprType::kCharacterClassStar: {
    +        if (initial_state.sub_element_id == 0) {
    +          bool is_negative = sub_sequence[0];
    +          for (int i = 1; i < sub_sequence.size(); i += 2) {
    +            int left_char = static_cast<uint8_t>(sub_sequence[i]);
    +            int right_char = static_cast<uint8_t>(sub_sequence[i + 1]);
    +            for (int c = left_char; c <= right_char; ++c) {
    +              first_character_mask[c] = true;
    +            }
    +          }
    +          if (is_negative) {
    +            first_character_mask = ~first_character_mask;
    +          }
    +          break;
    +        }
    +        // Otherwise, it's matching a UTF-8 character. We can optimize the matching process
    +        // here.
    +        for (size_t i = 0x80; i < 0xC0; ++i) {
    +          first_character_mask[i] = true;
    +        }
    +        break;
    +      }
    +      default: {
    +        XGRAMMAR_LOG(FATAL) << "Unsupported rule expr type: " << static_cast<int>(sequence.type);
    +      }
    +    }
    +  } else {
    +    XGRAMMAR_DCHECK(sequence.type == Grammar::Impl::RuleExprType::kTagDispatch);
    +    first_character_mask.set();
    +  }
    +  bool rejected_indices_are_filled = GetTokenMaskWithFirstCharacterCheck(
    +      sorted_decoded_vocab, first_character_mask, subtree_nodes_range, is_root_rule
       );
    +  if (rejected_indices_are_filled) {
    +    return AdaptiveTokenMask(
    +        vocab_size,
    +        sorted_decoded_vocab,
    +        tmp_accepted_indices_,
    +        tmp_rejected_indices_,
    +        tmp_uncertain_indices_
    +    );
    +  } else {
    +    return AdaptiveTokenMask(
    +        vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
    +    );
    +  }
     }
     
     /******************* GrammarCompiler::Impl *******************/
    @@ -525,71 +810,75 @@ CompiledGrammar GrammarCompiler::Impl::MultiThreadCompileGrammar(Grammar grammar
         adaptive_token_mask_cache_mutex.emplace();
       }
     
    -  auto add_adaptive_token_mask = [&](const StackElement& stack_element, bool is_root_rule) {
    -    auto grammar_matcher = GrammarMatcherForTokenMaskCache(grammar, stack_element);
    +  auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
    +    auto grammar_matcher = GrammarMatcherForTokenMaskCache(grammar, state, false);
         auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask(
    -        tokenizer_info_.GetVocabSize(), tokenizer_info_.GetSortedDecodedVocab(), is_root_rule
    +        tokenizer_info_.GetVocabSize(),
    +        tokenizer_info_.GetSortedDecodedVocab(),
    +        tokenizer_info_.GetTrieSubtreeNodesRange(),
    +        is_root_rule
         );
         if (max_threads_ > 1) {
           std::lock_guard<std::mutex> lock(adaptive_token_mask_cache_mutex.value());
    -      compiled_grammar_impl->adaptive_token_mask_cache[stack_element] =
    -          cur_adaptive_token_mask_cache;
    +      compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache;
         } else {
    -      compiled_grammar_impl->adaptive_token_mask_cache[stack_element] =
    -          cur_adaptive_token_mask_cache;
    +      compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache;
         }
       };
     
    -  auto add_task_adaptive_token_mask = [&](const StackElement& stack_element, bool is_root_rule) {
    +  auto add_task_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
         // Execute depending on whether we use thread_pool
         if (max_threads_ > 1) {
    -      thread_pool->Execute([add_adaptive_token_mask, stack_element, is_root_rule]() {
    -        add_adaptive_token_mask(stack_element, is_root_rule);
    +      thread_pool->Execute([add_adaptive_token_mask, state, is_root_rule]() {
    +        add_adaptive_token_mask(state, is_root_rule);
           });
         } else {
    -      add_adaptive_token_mask(stack_element, is_root_rule);
    +      add_adaptive_token_mask(state, is_root_rule);
         }
       };
     
       for (int32_t rule_id = 0; rule_id < static_cast<int>(grammar->NumRules()); ++rule_id) {
    -    auto rule = grammar->GetRule(rule_id);
    -    auto rule_body = grammar->GetRuleExpr(rule.body_expr_id);
    +    const auto& rule = grammar->GetRule(rule_id);
    +    const auto& rule_body = grammar->GetRuleExpr(rule.body_expr_id);
     
         if (rule_body.type == RuleExprType::kTagDispatch) {
    -      auto cur_stack_element = StackElement(rule_id, rule.body_expr_id, 0);
    +      auto state = ParserState(rule_id, rule.body_expr_id, 0, ParserState::kNoPrevInputPos, 0);
           for (int i = 0; i < grammar->root_tag_dispatch_fsm->NumStates(); ++i) {
    -        cur_stack_element.element_id = i;
    -        add_task_adaptive_token_mask(cur_stack_element, rule_id == root_rule_id);
    +        if (!grammar->root_tag_dispatch_fsm->IsEndState(i)) {
    +          state.element_id = i;
    +          add_task_adaptive_token_mask(state, rule_id == root_rule_id);
    +        }
           }
           continue;
         }
     
         XGRAMMAR_DCHECK(rule_body.type == RuleExprType::kChoices);
         for (auto sequence_id : rule_body) {
    -      auto sequence = grammar->GetRuleExpr(sequence_id);
    +      const auto& sequence = grammar->GetRuleExpr(sequence_id);
           if (sequence.type == RuleExprType::kEmptyStr) {
             continue;
           }
           XGRAMMAR_DCHECK(sequence.type == RuleExprType::kSequence);
    +      auto state = ParserState(rule_id, sequence_id, 0, ParserState::kNoPrevInputPos, 0);
           for (int element_id = 0; element_id < sequence.size(); ++element_id) {
    +        state.element_id = element_id;
             auto element = grammar->GetRuleExpr(sequence[element_id]);
             if (element.type == RuleExprType::kRuleRef) {
               continue;
             }
    -        auto cur_stack_element = StackElement(rule_id, sequence_id, element_id);
             if (element.type == RuleExprType::kByteString) {
               for (int idx = 0; idx < element.size(); ++idx) {
    -            cur_stack_element.element_in_string = idx;
    -            add_task_adaptive_token_mask(cur_stack_element, rule_id == root_rule_id);
    +            state.sub_element_id = idx;
    +            add_task_adaptive_token_mask(state, rule_id == root_rule_id);
               }
             } else {
               XGRAMMAR_DCHECK(
                   element.type == RuleExprType::kCharacterClassStar ||
                   element.type == RuleExprType::kCharacterClass
               );
               for (int left_utf8_bytes = 0; left_utf8_bytes <= 3; ++left_utf8_bytes) {
    -            cur_stack_element.left_utf8_bytes = left_utf8_bytes;
    -            add_task_adaptive_token_mask(cur_stack_element, rule_id == root_rule_id);
    +            state.sub_element_id = left_utf8_bytes;
    +            add_task_adaptive_token_mask(state, rule_id == root_rule_id);
               }
             }
           }
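
The token-mask loop in the diff above walks a lexicographically sorted vocabulary, reuses the longest common prefix with the previous token, and skips the whole trie subtree of a rejected token. The following is a minimal, self-contained sketch of that idea, not XGrammar's implementation. `DigitMatcher`, `ComputeSubtreeEnd`, and `ClassifyTokens` are hypothetical names introduced for illustration, and the toy matcher (which accepts only ASCII digits) stands in for the Earley parser's `Advance` / `PopLastStates` pair.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>
#include <vector>

// Hypothetical stand-in for the parser: accepts only ASCII digits and keeps
// one "state" per accepted character so that characters can be rolled back.
class DigitMatcher {
 public:
  bool Advance(char ch) {
    ++advance_calls_;  // counts every call, including rejected characters
    if (!std::isdigit(static_cast<unsigned char>(ch))) return false;
    ++depth_;
    return true;
  }
  void PopLastStates(int n) {
    assert(n >= 0 && n <= depth_);
    depth_ -= n;
  }
  int depth() const { return depth_; }
  int advance_calls() const { return advance_calls_; }

 private:
  int depth_ = 0;
  int advance_calls_ = 0;
};

// For each token i in a sorted vocabulary, returns the first index past the
// contiguous block of tokens that have token i as a prefix (its trie subtree).
std::vector<int> ComputeSubtreeEnd(const std::vector<std::string>& sorted_vocab) {
  const int n = static_cast<int>(sorted_vocab.size());
  std::vector<int> subtree_end(n);
  for (int i = 0; i < n; ++i) {
    int j = i + 1;
    while (j < n &&
           sorted_vocab[j].compare(0, sorted_vocab[i].size(), sorted_vocab[i]) == 0) {
      ++j;
    }
    subtree_end[i] = j;
  }
  return subtree_end;
}

// Returns one flag per token: true if the matcher accepts the whole token.
std::vector<bool> ClassifyTokens(
    const std::vector<std::string>& sorted_vocab, DigitMatcher& matcher
) {
  const std::vector<int> subtree_end = ComputeSubtreeEnd(sorted_vocab);
  std::vector<bool> accepted_flags(sorted_vocab.size(), false);
  const std::string* prev_token = nullptr;
  int prev_matched_size = 0;    // characters of the previous token still on the matcher
  int last_rejected_range = 0;  // tokens below this index lie in a rejected subtree

  for (int i = 0; i < static_cast<int>(sorted_vocab.size()); ++i) {
    if (i < last_rejected_range) continue;  // shares a rejected prefix: skip it
    const std::string& token = sorted_vocab[i];
    bool accepted = true;
    if (prev_token != nullptr) {
      int lcp_len = static_cast<int>(
          std::mismatch(token.begin(), token.end(), prev_token->begin(), prev_token->end())
              .first -
          token.begin()
      );
      if (lcp_len > prev_matched_size) {
        accepted = false;  // the shared prefix already failed on the previous token
      } else if (lcp_len < prev_matched_size) {
        matcher.PopLastStates(prev_matched_size - lcp_len);  // drop the non-common part
      }
      prev_matched_size = std::min(prev_matched_size, lcp_len);
    }
    prev_token = &token;
    if (accepted) {
      // Only the characters after the shared prefix need to be matched.
      for (int j = prev_matched_size; j < static_cast<int>(token.size()); ++j) {
        if (!matcher.Advance(token[j])) {
          accepted = false;
          break;
        }
        prev_matched_size = j + 1;
      }
    }
    if (accepted) {
      accepted_flags[i] = true;
    } else {
      last_rejected_range = subtree_end[i];
    }
  }
  matcher.PopLastStates(prev_matched_size);  // leave the matcher as it was found
  return accepted_flags;
}
```

For the sorted vocabulary `{"1", "12", "12a", "12ab", "13", "a", "ab"}`, this sketch calls `Advance` five times, whereas matching every token from scratch would take thirteen calls. The saving comes from the two shortcuts combined: `"12"` reuses the state left behind by `"1"`, and `"12ab"` and `"ab"` are never examined because they sit in the subtrees of the rejected tokens `"12a"` and `"a"`.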
    
  • cpp/grammar_matcher.cc +137 -125 modified
    @@ -8,10 +8,13 @@
     
     #include <xgrammar/matcher.h>
     
    +#include <cstdint>
    +#include <utility>
    +#include <vector>
    +
     #include "compiled_grammar_data_structure.h"
    +#include "earley_parser.h"
     #include "grammar_data_structure.h"
    -#include "grammar_matcher_base.h"
    -#include "persistent_stack.h"
     #include "support/dynamic_bitset.h"
     #include "support/encoding.h"
     #include "support/int_set.h"
    @@ -21,6 +24,7 @@
     namespace xgrammar {
     
     /******************* Tool functions for token mask *******************/
    +using RuleExprType = Grammar::Impl::RuleExprType;
     
     int32_t GetBitmaskSize(int vocab_size) { return DynamicBitset::GetBufferSize(vocab_size); }
     
    @@ -247,15 +251,15 @@ void ApplyTokenBitmaskInplaceCPU(
      */
     
     /* \brief The concrete implementation of GrammarMatcherNode. */
    -class GrammarMatcher::Impl : public GrammarMatcherBase {
    +class GrammarMatcher::Impl : public EarleyParser {
      public:
       Impl(
           const CompiledGrammar& compiled_grammar,
           std::optional<std::vector<int>> override_stop_tokens = std::nullopt,
           bool terminate_without_stop_token = false,
           int max_rollback_tokens = 0
       )
    -      : GrammarMatcherBase(compiled_grammar->grammar),
    +      : EarleyParser(compiled_grammar->grammar, ParserState::GetInvalidState()),
             compiled_grammar_(compiled_grammar),
             tokenizer_info_(compiled_grammar->tokenizer_info),
             stop_token_ids_(override_stop_tokens.value_or(tokenizer_info_.GetStopTokenIds())),
    @@ -278,17 +282,13 @@ class GrammarMatcher::Impl : public GrammarMatcherBase {
     
       bool IsTerminated() const;
     
    -  void Reset() {
    -    stack_tops_history_.Reset();
    -    token_length_history.clear();
    -    PushInitialState(kInvalidStackElement, true);
    -  }
    +  void Reset() { EarleyParser::Reset(); }
     
       int GetMaxRollbackTokens() const { return max_rollback_tokens_; }
     
       const std::vector<int>& GetStopTokenIds() const { return stop_token_ids_; }
     
    -  std::string _DebugPrintInternalState() const { return PrintStackState(); }
    +  std::string _DebugPrintInternalState() const { return PrintStates(); }
     
      private:
       using StoreType = AdaptiveTokenMask::StoreType;
    @@ -345,24 +345,23 @@ bool GrammarMatcher::Impl::AcceptStopToken() {
       if (terminate_without_stop_token_) {
         return false;
       }
    -  if (!CanReachEnd()) {
    +  if (!IsCompleted()) {
         return false;
       }
    -  stack_tops_history_.PushHistory({});  // Terminate the matcher by setting the stack to empty
    -  token_length_history.push_back(1);  // When rolling back a stop token, we need to rollback 1 state
    +  XGRAMMAR_DCHECK(!stop_token_is_accepted_);
    +  token_length_history.push_back(0);
    +  stop_token_is_accepted_ = true;
       return true;
     }
     
     bool GrammarMatcher::Impl::IsTerminated() const {
       if (terminate_without_stop_token_) {
    -    return CanReachEnd();
    +    return IsCompleted();
       }
       return IsStopTokenAccepted();
     }
     
    -bool GrammarMatcher::Impl::IsStopTokenAccepted() const {
    -  return stack_tops_history_.GetLatest().empty();
    -}
    +bool GrammarMatcher::Impl::IsStopTokenAccepted() const { return stop_token_is_accepted_; }
     
     // TODO(yixin): Polish verbose logging
     bool GrammarMatcher::Impl::AcceptToken(int32_t token_id, bool debug_print) {
    @@ -379,10 +378,15 @@ bool GrammarMatcher::Impl::AcceptToken(int32_t token_id, bool debug_print) {
       }
     
       if (debug_print) {
    -    XGRAMMAR_LOG(INFO) << "Accepting token #" << token_id << "<"
    -                       << PrintAsEscapedUTF8(tokenizer_info_.GetDecodedVocab()[token_id]) << ">";
    +    std::string states_str;
    +    for (const auto& state : GetLatestScanableStates()) {
    +      states_str += "  " + state.ToString() + "\n";
    +    }
    +    XGRAMMAR_LOG(INFO) << "Accepting token id " << token_id << ", string: \""
    +                       << PrintAsEscapedUTF8(tokenizer_info_.GetDecodedVocab()[token_id])
    +                       << "\", current scannable states:\n"
    +                       << states_str;
       }
    -
       // Handle the stop token
       if (std::find(stop_token_ids_.begin(), stop_token_ids_.end(), token_id) !=
           stop_token_ids_.end()) {
    @@ -405,20 +409,19 @@ bool GrammarMatcher::Impl::AcceptToken(int32_t token_id, bool debug_print) {
       const auto& token = tokenizer_info_.GetDecodedVocab()[token_id];
       int pos = 0;
       for (auto char_value : token) {
    -    if (!AcceptChar(char_value, debug_print)) {
    +    if (!Advance(char_value)) {
           if (debug_print) {
             XGRAMMAR_LOG(INFO) << "Token #" << token_id << "<" << PrintAsEscapedUTF8(token)
                                << "> rejected at position " << pos << ", char "
                                << PrintAsEscapedUTF8(char_value);
           }
    -      RollbackChars(pos);
    +      PopLastStates(pos);
           return false;
         }
         ++pos;
       }
       token_length_history.push_back(token.size());
       if (static_cast<int>(token_length_history.size()) > max_rollback_tokens_) {
    -    DiscardEarliestChars(token_length_history.front());
         token_length_history.pop_front();
       }
     
    @@ -440,24 +443,29 @@ bool GrammarMatcher::Impl::AcceptString(const std::string& input_str, bool debug
     
       int accepted_cnt = 0;
       for (auto char_value : input_str) {
    -    if (!AcceptChar(char_value, debug_print)) {
    +    if (!Advance(char_value)) {
           if (debug_print) {
             XGRAMMAR_LOG(INFO) << "String \"" << PrintAsEscapedUTF8(input_str)
                                << "\" rejected at position " << accepted_cnt << ", char "
                                << PrintAsEscapedUTF8(char_value);
           }
    -      RollbackChars(accepted_cnt);
    +      PopLastStates(accepted_cnt);
           return false;
         }
         ++accepted_cnt;
       }
       token_length_history.push_back(input_str.size());
       if (static_cast<int>(token_length_history.size()) > max_rollback_tokens_) {
    -    DiscardEarliestChars(token_length_history.front());
         token_length_history.pop_front();
       }
       if (debug_print) {
    -    XGRAMMAR_LOG(INFO) << "String \"" << PrintAsEscapedUTF8(input_str) << "\" accepted.";
    +    std::string states_str;
    +    for (const auto& state : GetLatestScanableStates()) {
    +      states_str += "  " + state.ToString() + "\n";
    +    }
    +    XGRAMMAR_LOG(INFO) << "String \"" << PrintAsEscapedUTF8(input_str)
    +                       << "\" is accepted. Current scannable states:\n"
    +                       << states_str;
       }
       return true;
     }
    @@ -501,59 +509,59 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
       int32_t* bitmask_data_ptr =
           CheckAndGetBitmaskPtr(*next_token_bitmask, tokenizer_info_.GetVocabSize(), index);
       const auto& sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab();
    +  const auto& subtree_range = tokenizer_info_.GetTrieSubtreeNodesRange();
       const auto& adaptive_token_mask_cache = compiled_grammar_->adaptive_token_mask_cache;
    -  const auto& latest_stack_tops = stack_tops_history_.GetLatest();
    +  // We need to have a copy, because scanable_state_history_ will be modified during the
    +  // FillNextTokenBitmask process, which can lead to undefined behavior.
    +  auto latest_states = GetLatestScanableStates();
     
    -  // We check all the stacks one by one, and find the accepted token set or the rejected token set
    -  // for each stack. We will try to find the small one of the two sets.
    -  // The final accepted token set is the union of the accepted token sets of all stacks.
    -  // The final rejected token set is the intersection of the rejected token sets of all stacks.
    +  // We check all the latest states of the earley parser, and check all the masks of the leaf
    +  // states. The final accepted token set is the union of the accepted token sets of all leaf
    +  // states. The final rejected token set is the intersection of the rejected token sets of all leaf
    +  // states.
     
       // Note these indices store the indices in sorted_decoded_vocab, instead of the token ids.
       tmp_accepted_bitset_.Reset();
       // {-1} means the universal set, i.e. all tokens initially
       tmp_rejected_indices_.assign({-1});
     
    -  // If there is a stack top that is a tag dispatch, we allow special tokens to be accepted
    +  // If there is a leaf ParserState that is a tag dispatch, we allow special tokens to be accepted
       // because in function calling cases, only the part within the tag is constrained
       bool have_tag_dispatch = false;
     
       if (debug_print) {
         XGRAMMAR_LOG(INFO) << "FillNextTokenBitmask: index=" << index
    -                       << ", num of stacks=" << latest_stack_tops.size();
    +                       << ", num of states=" << latest_states.size();
       }
     
    -  int stack_top_cnt = -1;
    -
    -  for (auto top : latest_stack_tops) {
    -    ++stack_top_cnt;
    -    auto cur_stack_element = persistent_stack_[top];
    -    auto cur_sequence = grammar_->GetRuleExpr(cur_stack_element.sequence_id);
    -    if (cur_sequence.type != RuleExprType::kTagDispatch &&
    -        cur_stack_element.parent_id == StackElement::kNoParent &&
    -        cur_stack_element.element_id == cur_sequence.size()) {
    -      continue;
    -    }
    -
    -    if (cur_sequence.type == RuleExprType::kTagDispatch) {
    -      have_tag_dispatch = true;
    +  std::vector<std::pair<ParserState, decltype(adaptive_token_mask_cache.cbegin())>>
    +      latest_states_with_masks;
    +
    +  for (const auto& state : latest_states) {
    +    auto cur_sequence = grammar_->GetRuleExpr(state.sequence_id);
    +    XGRAMMAR_DCHECK(!(
    +        cur_sequence.type == RuleExprType::kRuleRef ||
    +        cur_sequence.type == RuleExprType::kChoices || cur_sequence.type == RuleExprType::kEmptyStr
    +    ));
    +    have_tag_dispatch = cur_sequence.type == RuleExprType::kTagDispatch;
    +    XGRAMMAR_DCHECK(have_tag_dispatch || cur_sequence.type == RuleExprType::kSequence);
    +    auto adaptive_token_mask_it = adaptive_token_mask_cache.find(state);
    +    XGRAMMAR_CHECK(adaptive_token_mask_it != adaptive_token_mask_cache.end()) << state;
    +    const auto& adaptive_token_mask = adaptive_token_mask_it->second;
    +    latest_states_with_masks.push_back(std::make_pair(state, adaptive_token_mask_it));
    +    if (adaptive_token_mask.store_type == StoreType::kAcceptedBitset) {
    +      tmp_accepted_bitset_ |= adaptive_token_mask.accepted_bitset;
    +    } else if (adaptive_token_mask.store_type == StoreType::kAccepted) {
    +      for (auto idx : adaptive_token_mask.accepted_indices) {
    +        tmp_accepted_bitset_.Set(sorted_decoded_vocab[idx].first, true);
    +      }
         }
    +  }
     
    -    auto adaptive_token_mask_it = adaptive_token_mask_cache.find(cur_stack_element);
    -    XGRAMMAR_CHECK(adaptive_token_mask_it != adaptive_token_mask_cache.end())
    -        << "The adaptive token mask is not found for stack element: "
    -        << persistent_stack_.PrintStackElement(cur_stack_element);
    -
    +  for (const auto& [state, adaptive_token_mask_it] : latest_states_with_masks) {
         const auto& adaptive_token_mask = adaptive_token_mask_it->second;
     
    -    if (debug_print) {
    -      XGRAMMAR_LOG(INFO) << "FillNextTokenBitmask: Stack #" << stack_top_cnt
    -                         << ", num_uncertain_tokens="
    -                         << adaptive_token_mask.uncertain_indices.size() << ": "
    -                         << persistent_stack_.PrintStackByTopId(top) << "\n";
    -    }
    -
    -    // For each stack, we will check every uncertain token and put them into the accepted or
    +    // For each ParserState, we will check every uncertain token and put them into the accepted or
         // rejected list.
     
         // Step 2. Update the accepted tokens in accepted_indices_delta, or the rejected tokens in
    @@ -564,13 +572,31 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
     
         tmp_rejected_indices_delta_.clear();
     
    -    // Examine only the current one stack
    -    stack_tops_history_.PushHistory({persistent_stack_.NewNode(cur_stack_element)});
    +    // Examine only the current one ParserState
    +    PushOneStateToCheck(state);
     
         const std::string* prev_token = nullptr;
         int prev_matched_size = 0;
    +    if (debug_print) {
    +      XGRAMMAR_LOG(INFO) << "The ParserState is " << state << ", the mask is "
    +                         << adaptive_token_mask.Print(tokenizer_info_);
    +    }
    +    int last_rejected_uncertain_range = 0;
    +    for (const auto& cur_token_idx : adaptive_token_mask.uncertain_indices) {
    +      // Check if the current token is already accepted. If it is, we can skip it.
    +      if (tmp_accepted_bitset_[sorted_decoded_vocab[cur_token_idx].first]) {
    +        continue;
    +      }
    +
    +      // Check if the current token is in the rejected range. i.e. check if the current token
    +      // is on the subtree of the rejected token.
    +      if (cur_token_idx < last_rejected_uncertain_range) {
    +        if (adaptive_token_mask.store_type == StoreType::kRejected) {
    +          tmp_rejected_indices_delta_.push_back(cur_token_idx);
    +        }
    +        continue;
    +      }
     
    -    for (auto cur_token_idx : adaptive_token_mask.uncertain_indices) {
           const auto& cur_token = sorted_decoded_vocab[cur_token_idx].second;
           bool accepted = true;
     
    @@ -583,17 +609,19 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
                               .first -
                           cur_token.begin();
             if (lcp_len > prev_matched_size) {
    +          last_rejected_uncertain_range = subtree_range[cur_token_idx];
               accepted = false;
             } else if (lcp_len < prev_matched_size) {
    -          RollbackChars(prev_matched_size - lcp_len);
    +          PopLastStates(prev_matched_size - lcp_len);
             }
             prev_matched_size = std::min(prev_matched_size, lcp_len);
           }
     
           // Step 2.2. Find if the current token is accepted or rejected.
           if (accepted) {
             for (int j = prev_matched_size; j < static_cast<int>(cur_token.size()); ++j) {
    -          if (!AcceptChar(cur_token[j], false)) {
    +          if (!Advance(cur_token[j])) {
    +            last_rejected_uncertain_range = subtree_range[cur_token_idx];
                 accepted = false;
                 break;
               }
    @@ -616,16 +644,9 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
           prev_token = &cur_token;
         }
     
    -    RollbackChars(prev_matched_size + 1);
    -
    +    PopLastStates(prev_matched_size + 1);
         // Step 3. Update the accepted_indices or rejected_indices
    -    if (adaptive_token_mask.store_type == StoreType::kAcceptedBitset) {
    -      tmp_accepted_bitset_ |= adaptive_token_mask.accepted_bitset;
    -    } else if (adaptive_token_mask.store_type == StoreType::kAccepted) {
    -      for (auto idx : adaptive_token_mask.accepted_indices) {
    -        tmp_accepted_bitset_.Set(sorted_decoded_vocab[idx].first, true);
    -      }
    -    } else {
    +    if (adaptive_token_mask.store_type == StoreType::kRejected) {
           // rejected_indices = Intersect(
           //     rejected_indices,
           //     adaptive_token_mask.rejected_indices + rejected_indices_delta)
    @@ -635,7 +656,7 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
       }
     
       // Finally update the rejected_ids bitset
    -  bool can_reach_end = CanReachEnd();
    +  bool can_reach_end = IsCompleted();
       SetTokenBitmask(
           bitmask_data_ptr,
           tmp_accepted_bitset_,
    @@ -659,53 +680,56 @@ std::string GrammarMatcher::Impl::FindJumpForwardString() {
       bool can_find_next_char = true;
     
       while (can_find_next_char) {
    -    const auto& stack_tops = stack_tops_history_.GetLatest();
    +    const auto& states = scanable_state_history_[scanable_state_history_.size() - 1];
     
    -    // 1. Check that for every stack top, the next possible char is unique and the same
    +    // 1. Check that for every leaf ParserState, the next possible char is unique and the same
         // -1 means not found yet; 0~255 means the next char
         int next_char = -1;
    -    for (auto stack_top : stack_tops) {
    -      auto stack_element = persistent_stack_[stack_top];
    -      auto cur_sequence = grammar_->GetRuleExpr(stack_element.sequence_id);
    -
    -      // We cannot deduce the next char for tag dispatch
    -      if (cur_sequence.type == RuleExprType::kTagDispatch) {
    +    for (const auto& ParserState : states) {
    +      if (IsCompleted()) {
             can_find_next_char = false;
    -        continue;
    +        break;
           }
    -
    -      // The state comes to the end of the grammar
    -      if (stack_element.parent_id == StackElement::kNoParent &&
    -          stack_element.element_id == cur_sequence.size()) {
    +      auto cur_sequence = grammar_->GetRuleExpr(ParserState.sequence_id);
    +      // We cannot deduce the next char for tag dispatch
    +      if (cur_sequence.type == RuleExprType::kTagDispatch) {
             can_find_next_char = false;
             break;
           }
    -
    -      auto cur_element = grammar_->GetRuleExpr(cur_sequence[stack_element.element_id]);
    -
    +      // The ParserState comes to the end of the grammar
    +      XGRAMMAR_DCHECK(ParserState.element_id != cur_sequence.size());
    +      XGRAMMAR_DCHECK(
    +          cur_sequence.type != RuleExprType::kChoices &&
    +          cur_sequence.type != RuleExprType::kEmptyStr
    +      );
    +      const auto& cur_element = grammar_->GetRuleExpr(cur_sequence[ParserState.element_id]);
           if (cur_element.type == RuleExprType::kByteString) {
    -        XGRAMMAR_DCHECK(stack_element.element_in_string < cur_element.size());
    +        XGRAMMAR_DCHECK(ParserState.sub_element_id < cur_element.size());
             if (next_char == -1) {
    -          next_char = cur_element[stack_element.element_in_string];
    -        } else if (next_char != cur_element[stack_element.element_in_string]) {
    -          can_find_next_char = false;
    -          break;
    -        }
    -      } else {
    -        XGRAMMAR_DCHECK(
    -            cur_element.type == RuleExprType::kCharacterClass ||
    -            cur_element.type == RuleExprType::kCharacterClassStar
    -        );
    -        if (stack_element.left_utf8_bytes > 0 || cur_element.size() != 3 || cur_element[0] != 0 ||
    -            cur_element[1] != cur_element[2]) {
    -          can_find_next_char = false;
    -          break;
    -        } else if (next_char == -1) {
    -          next_char = cur_element[1];
    -        } else if (next_char != cur_element[1]) {
    +          next_char = cur_element[ParserState.sub_element_id];
    +        } else if (next_char != cur_element[ParserState.sub_element_id]) {
               can_find_next_char = false;
               break;
             }
    +        continue;
    +      }
    +      if (cur_element.type == RuleExprType::kRuleRef) {
    +        continue;
    +      }
    +
    +      XGRAMMAR_DCHECK(
    +          cur_element.type == RuleExprType::kCharacterClass ||
    +          cur_element.type == RuleExprType::kCharacterClassStar
    +      );
    +      if (ParserState.sub_element_id > 0 || cur_element.size() != 3 || cur_element[0] != 0 ||
    +          cur_element[1] != cur_element[2]) {
    +        can_find_next_char = false;
    +        break;
    +      } else if (next_char == -1) {
    +        next_char = cur_element[1];
    +      } else if (next_char != cur_element[1]) {
    +        can_find_next_char = false;
    +        break;
           }
         }
     
    @@ -716,25 +740,13 @@ std::string GrammarMatcher::Impl::FindJumpForwardString() {
         // 2. If found, accept the char and iterate to the next position
         if (can_find_next_char) {
           result += static_cast<uint8_t>(next_char);
    -
    -      tmp_new_stack_tops_.clear();
    -      for (auto stack_top : stack_tops) {
    -        auto cur_stack_element = persistent_stack_[stack_top];
    -        auto new_stack_element = AdvanceStackElementWithChar(cur_stack_element, next_char);
    -
    -        if (new_stack_element == cur_stack_element) {
    -          ExpandEquivalentStackElements(new_stack_element, &tmp_new_stack_tops_, stack_top);
    -        } else {
    -          ExpandEquivalentStackElements(new_stack_element, &tmp_new_stack_tops_);
    -        }
    -      }
    -      stack_tops_history_.PushHistory(tmp_new_stack_tops_);
    +      Advance(next_char);
           ++num_accepted_chars;
         }
       }
     
       // Rollback all chars accepted
    -  RollbackChars(num_accepted_chars);
    +  PopLastStates(num_accepted_chars);
       return result;
     }
     
    @@ -744,7 +756,7 @@ void GrammarMatcher::Impl::Rollback(int num_tokens) {
           << token_length_history.size() << " steps of history are saved";
       while (num_tokens > 0) {
         int steps = token_length_history.back();
    -    RollbackChars(steps);
    +    PopLastStates(steps);
         token_length_history.pop_back();
         --num_tokens;
       }
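
The `FindJumpForwardString` hunk above extends the result only while every live parser state agrees on a single next byte, and stops once the grammar may already be complete. A minimal sketch of that rule follows. It assumes each state can be reduced to the literal bytes it still expects; `jump_forward_string` and `remaining` are hypothetical names, not XGrammar APIs, and the real code also handles character classes, rule references and tag dispatch.

```python
from typing import List


def jump_forward_string(remaining: List[bytes]) -> bytes:
    """Return the longest byte string that every live alternative forces next.

    `remaining` is a hypothetical stand-in for the parser states: each entry
    is the literal bytes that one alternative still expects.
    """
    result = bytearray()
    pos = 0
    while remaining:
        # An exhausted alternative means the input may already be complete,
        # so the next byte is no longer forced (compare the IsCompleted() check).
        if any(pos >= len(alt) for alt in remaining):
            break
        next_bytes = {alt[pos] for alt in remaining}
        # More than one candidate byte: the continuation is ambiguous.
        if len(next_bytes) != 1:
            break
        result.append(next_bytes.pop())
        pos += 1
    return bytes(result)
```

Because this sketch never mutates its input, it needs no counterpart to the `PopLastStates(num_accepted_chars)` rollback that the patched function performs after advancing.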
    
  • cpp/serialize_json.cc+1 1 modified
    @@ -20,7 +20,7 @@
     
     namespace xgrammar {
     
    -static constexpr const char kXGrammarSerializeVersion[] = "v1";
    +static constexpr const char kXGrammarSerializeVersion[] = "v2";
     
     bool TokenizerInfo::Impl::operator==(const TokenizerInfo::Impl& other) const {
       static constexpr auto tie = [](const TokenizerInfo::Impl& impl) {
    
  • cpp/support/csr_array.h+17 7 modified
    @@ -113,21 +113,21 @@ class CSRArray {
        * \param data_len Length of the data to be inserted.
        * \return The index of the newly inserted row.
        */
    -  int32_t Insert(const DataType* new_data, int32_t new_data_len);
    +  int32_t PushBack(const DataType* new_data, int32_t new_data_len);
     
       /*!
        * \brief Insert a new row of data into the CSRArray from a vector.
        * \param data Vector containing the data to be inserted.
        * \return The index of the newly inserted row.
        */
    -  int32_t Insert(const std::vector<DataType>& new_data);
    +  int32_t PushBack(const std::vector<DataType>& new_data);
     
       /*!
        * \brief Insert a new row of data into the CSRArray from a Row struct.
        * \param row The Row struct containing the data to be inserted.
        * \return The index of the newly inserted row.
        */
    -  int32_t Insert(const Row& row) { return Insert(row.data, row.data_len); }
    +  int32_t PushBack(const Row& row) { return PushBack(row.data, row.data_len); }
     
       /*!
        * \brief Insert a new row of non-contiguous data into the CSRArray. This method inserts a
    @@ -138,7 +138,17 @@ class CSRArray {
        * \param data_2_len Length of the remaining data to be inserted.
        * \return The index of the newly inserted row.
        */
    -  int32_t InsertNonContiguous(DataType data_1, const DataType* data_2, int32_t data_2_len);
    +  int32_t PushBackNonContiguous(DataType data_1, const DataType* data_2, int32_t data_2_len);
    +
    +  /*!
    +   * \brief Pop back the last one or multiple rows of the CSRArray.
    +   * \param cnt The number of rows to be popped.
    +   */
    +  void PopBack(const int32_t& cnt) {
    +    indptr_.erase(indptr_.end() - cnt, indptr_.end());
    +    data_.erase(data_.begin() + indptr_.back(), data_.end());
    +    return;
    +  }
     
       /****************** Internal Accessors ******************/
     
    @@ -178,7 +188,7 @@ inline typename CSRArray<DataType>::Row CSRArray<DataType>::operator[](int32_t i
     }
     
     template <typename DataType>
    -inline int32_t CSRArray<DataType>::Insert(const DataType* new_data, int32_t new_data_len) {
    +inline int32_t CSRArray<DataType>::PushBack(const DataType* new_data, int32_t new_data_len) {
       // TODO(yixin): whether to add a additional data_len
       // If the new data is already in the CSRArray, we need to copy it to the new memory location.
       if (new_data >= data_.data() && new_data < data_.data() + data_.size()) {
    @@ -192,14 +202,14 @@ inline int32_t CSRArray<DataType>::Insert(const DataType* new_data, int32_t new_
     }
     
     template <typename DataType>
    -inline int32_t CSRArray<DataType>::Insert(const std::vector<DataType>& new_data) {
    +inline int32_t CSRArray<DataType>::PushBack(const std::vector<DataType>& new_data) {
       data_.insert(data_.end(), new_data.begin(), new_data.end());
       indptr_.push_back(static_cast<int32_t>(data_.size()));
       return static_cast<int32_t>(indptr_.size()) - 2;
     }
     
     template <typename DataType>
    -inline int32_t CSRArray<DataType>::InsertNonContiguous(
    +inline int32_t CSRArray<DataType>::PushBackNonContiguous(
         DataType data_1, const DataType* data_2, int32_t data_2_len
     ) {
       if (data_2 >= data_.data() && data_2 < data_.data() + data_.size()) {
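
The `csr_array.h` change renames `Insert` to `PushBack` and adds `PopBack`, which truncates `indptr_` first and then cuts `data_` at the new last offset. The sketch below mirrors that order of operations in Python. `CSRArraySketch` is an illustrative class, not part of XGrammar, and the bounds check in `pop_back` is an addition that the C++ version in the diff does not have.

```python
from typing import List


class CSRArraySketch:
    """Rows of ints stored back to back in `data`, delimited by `indptr`."""

    def __init__(self) -> None:
        self.data: List[int] = []
        # Row i spans data[indptr[i]:indptr[i + 1]].
        self.indptr: List[int] = [0]

    def __len__(self) -> int:
        return len(self.indptr) - 1

    def row(self, i: int) -> List[int]:
        return self.data[self.indptr[i]:self.indptr[i + 1]]

    def push_back(self, new_row: List[int]) -> int:
        """Append a row and return its index."""
        self.data.extend(new_row)
        self.indptr.append(len(self.data))
        return len(self.indptr) - 2

    def pop_back(self, cnt: int) -> None:
        """Drop the last `cnt` rows."""
        if not 0 <= cnt <= len(self):
            raise ValueError("cnt must be between 0 and the number of rows")
        if cnt:  # guard: `del xs[-0:]` would clear the whole list
            del self.indptr[-cnt:]
            del self.data[self.indptr[-1]:]
```

Popping is cheap in this layout because rows are only ever added or removed at the tail, so no offsets need to be recomputed.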
    
  • cpp/tokenizer_info.cc+23 0 modified
    @@ -10,6 +10,7 @@
     #include <array>
     #include <memory>
     #include <optional>
    +#include <stack>
     #include <string>
     #include <unordered_map>
     #include <vector>
    @@ -293,6 +294,24 @@ TokenizerInfo::Impl::Impl(
         return a.second < b.second;
       };
       std::sort(sorted_decoded_vocab_.begin(), sorted_decoded_vocab_.end(), f_compare_token);
    +
    +  // The value means: the subtree is [i, trie_subtree_nodes_range[i]).
    +  trie_subtree_nodes_range.resize(sorted_decoded_vocab_.size(), 0);
    +  std::stack<std::pair<std::string, int32_t>> prefix_stack;
    +  for (size_t i = 0; i < sorted_decoded_vocab_.size(); ++i) {
    +    const auto& token = sorted_decoded_vocab_[i].second;
    +    while ((!prefix_stack.empty()) && (token.find(prefix_stack.top().first) == std::string::npos)) {
    +      const auto& top_pair = prefix_stack.top();
    +      trie_subtree_nodes_range[top_pair.second] = i;
    +      prefix_stack.pop();
    +    }
    +    prefix_stack.push({token, i});
    +  }
    +  while (!prefix_stack.empty()) {
    +    const auto& top_pair = prefix_stack.top();
    +    trie_subtree_nodes_range[top_pair.second] = sorted_decoded_vocab_.size();
    +    prefix_stack.pop();
    +  }
     }
     
     std::string TokenizerInfo::Impl::DumpMetadata() const {
    @@ -389,6 +408,10 @@ const std::vector<std::pair<int32_t, std::string>>& TokenizerInfo::GetSortedDeco
       return pimpl_->GetSortedDecodedVocab();
     }
     
    +const std::vector<int32_t>& TokenizerInfo::GetTrieSubtreeNodesRange() const {
    +  return pimpl_->GetTrieSubtreeNodesRange();
    +}
    +
     std::string TokenizerInfo::DumpMetadata() const { return pimpl_->DumpMetadata(); }
     
     TokenizerInfo TokenizerInfo::FromVocabAndMetadata(
    
  • cpp/tokenizer_info_impl.h+4 0 modified
    @@ -35,6 +35,7 @@ class TokenizerInfo::Impl {
       const std::vector<std::pair<int32_t, std::string>>& GetSortedDecodedVocab() const {
         return sorted_decoded_vocab_;
       }
    +  const std::vector<int32_t>& GetTrieSubtreeNodesRange() const { return trie_subtree_nodes_range; }
     
       std::string DumpMetadata() const;
     
    @@ -61,6 +62,9 @@ class TokenizerInfo::Impl {
       /*! \brief All (id, token) pairs sorted in lexicographic order. This sorting is done to
        * maximize prefix reuse during matching. Special tokens and stop tokens are not included. */
       std::vector<std::pair<int32_t, std::string>> sorted_decoded_vocab_;
    +  /*! \brief A pesudo-trie. trie_subtree_nodes_range[i] stores how many nodes there are in the
    +   * subtree. */
    +  std::vector<int32_t> trie_subtree_nodes_range;
       /*! \brief The stop tokens. When the GrammarMatcher can reach the end of the grammar,
        * stop tokens can be accepted. */
       std::vector<int32_t> stop_token_ids_;
    
  • include/xgrammar/tokenizer_info.h+1 0 modified
    @@ -39,6 +39,7 @@ class TokenizerInfo {
       const std::vector<int32_t>& GetStopTokenIds() const;
       const std::vector<int32_t>& GetSpecialTokenIds() const;
       const std::vector<std::pair<int32_t, std::string>>& GetSortedDecodedVocab() const;
    +  const std::vector<int32_t>& GetTrieSubtreeNodesRange() const;
       std::string DumpMetadata() const;
     
       static TokenizerInfo FromVocabAndMetadata(
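
The loop added to `TokenizerInfo::Impl::Impl` walks the lexicographically sorted vocabulary with a stack of open prefixes and records, for each token `i`, the end of the half-open range `[i, trie_subtree_nodes_range[i])`. A sketch of the same idea is shown here. `subtree_ranges` is a hypothetical name, and the sketch assumes a strict prefix test (`str.startswith`), whereas the diff calls `std::string::find`, which matches the substring at any position.

```python
from typing import List, Tuple


def subtree_ranges(sorted_tokens: List[str]) -> List[int]:
    """For each token i, return the end of the run [i, end) that it prefixes.

    `sorted_tokens` must already be in lexicographic order.
    """
    ranges = [0] * len(sorted_tokens)
    stack: List[Tuple[str, int]] = []  # (prefix, index) pairs still open
    for i, token in enumerate(sorted_tokens):
        # Close every open prefix that the current token does not extend.
        while stack and not token.startswith(stack[-1][0]):
            _, start = stack.pop()
            ranges[start] = i
        stack.append((token, i))
    # Prefixes still open at the end extend to the end of the vocabulary.
    while stack:
        _, start = stack.pop()
        ranges[start] = len(sorted_tokens)
    return ranges
```

One plausible use, not confirmed by the diff shown here, is that a matcher which rejects token `i` can skip straight to index `ranges[i]`, since every token in between starts with the rejected one.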
    
  • python/xgrammar/matcher.py+20 10 modified
    @@ -3,6 +3,7 @@
     """
     
     import math
    +import warnings
     from typing import List, Optional, Tuple, Union
     
     import torch
    @@ -175,28 +176,37 @@ def __init__(
             *,
             override_stop_tokens: Optional[Union[int, List[int]]] = None,
             terminate_without_stop_token: bool = False,
    -        max_rollback_tokens: int = 0,
    +        max_rollback_tokens: int = -1,
         ) -> None:
             """Construct the grammar matcher.
     
    -        Parameters
    -        ----------
    -        compiled_grammar : CompiledGrammar
    -            The initialization context for the grammar matcher.
    +            Parameters
    +            ----------
    +            compiled_grammar : CompiledGrammar
    +                The initialization context for the grammar matcher.
    +
    +            override_stop_tokens : Optional[Union[int, List[int]]], default: None
    +                If not None, the stop tokens to override the ones in the grammar.
     
    -        override_stop_tokens : Optional[Union[int, List[int]]], default: None
    -            If not None, the stop tokens to override the ones in the grammar.
    +            terminate_without_stop_token : bool, default: False
    +                Whether to terminate the matcher without accepting a stop token.
     
    -        terminate_without_stop_token : bool, default: False
    -            Whether to terminate the matcher without accepting a stop token.
    +        max_rollback_tokens : int, default: -1
    +            Deprecated because the earley parser significantly reduces the number of states, so not
    +            needed anymore.
     
    -        max_rollback_tokens : int, default: 0
                 The maximum number of rollback tokens allowed. The rollback operation is useful for
                 jump-forward decoding and speculative decoding.
             """
             if not isinstance(compiled_grammar, CompiledGrammar):
                 raise ValueError("The grammar should be compiled before passing it to GrammarMatcher.")
     
    +        if not max_rollback_tokens == -1:
    +            warnings.warn(
    +                "max_rollback_tokens is deprecated because the earley parser significantly reduces the number of states, so not needed anymore.",
    +                DeprecationWarning,
    +            )
    +
             if isinstance(override_stop_tokens, int):
                 override_stop_tokens = [override_stop_tokens]
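
The `matcher.py` change deprecates `max_rollback_tokens` by switching its default to the sentinel `-1` and warning whenever a caller passes anything else. The pattern can be sketched in isolation as follows. `make_matcher` is a hypothetical stand-in for the constructor, and `stacklevel=2` is an addition that the diff does not use.

```python
import warnings

_UNSET = -1  # sentinel default, matching the new signature in the diff


def make_matcher(max_rollback_tokens: int = _UNSET) -> dict:
    """Hypothetical constructor stand-in that only demonstrates the warning."""
    if max_rollback_tokens != _UNSET:
        warnings.warn(
            "max_rollback_tokens is deprecated and no longer needed.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
    return {"max_rollback_tokens": max_rollback_tokens}
```

Python hides `DeprecationWarning` by default unless it is triggered from code in `__main__`, so callers may need `python -W default` or a test runner that enables warnings to see it.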
     
    
  • tests/cpp/test_serialization.cc+2 2 modified
    @@ -44,8 +44,8 @@ TEST(XGrammarReflectionTest, JSONSerialization) {
     
       // CSRArray use a data_ and indptr_ structure
       auto array = CSRArray<int>{};
    -  array.Insert({0, 1, 2, 3});
    -  array.Insert({4, 5, 6, 7});
    +  array.PushBack({0, 1, 2, 3});
    +  array.PushBack({4, 5, 6, 7});
       auto deserialized_array = CSRArray<int>{};
     
       auto json_array = AutoSerializeJSONValue(array);
    
  • tests/python/test_grammar_matcher_ebnf.py+94 1 modified
    @@ -11,7 +11,11 @@
     from transformers import AutoTokenizer
     
     import xgrammar as xgr
    -from xgrammar.testing import _get_masked_tokens_from_bitmask, _is_grammar_accept_string
    +from xgrammar.testing import (
    +    _get_masked_tokens_from_bitmask,
    +    _get_matcher_from_grammar_and_tokenizer_info,
    +    _is_grammar_accept_string,
    +)
     
     
     def test_simple():
    @@ -399,5 +403,94 @@ def test_fill_next_token_bitmask(
         assert len(rejected_token_ids) == expected_rejected_sizes[-1]
     
     
    +def test_nullable_grammar():
    +    grammar_with_nullable_rules = """
    +    root ::= rule1 | (rule1 rule1 rule1 rule3)+
    +    rule1 ::= rule2
    +    rule2 ::= [0-9]*
    +    rule3 ::= [a-z]
    +"""
    +    test_string = ["abc12312398014a", ""]
    +
    +    for s in test_string:
    +        assert _is_grammar_accept_string(grammar_with_nullable_rules, s)
    +
    +
    +def test_predict_complete():
    +    # Test complex prediction and completion with EBNF grammar.
    +    mixed_grammar_str = """root ::= rule1 [0-9]?
    +    rule1 ::= rule2 [0-9]? | rule4 [0-9]?
    +    rule2 ::= rule3 [0-9]? | rule2 [0-9]? | rule1 [0-9]?
    +    rule3 ::= rule4 [0-9]? | rule5 [0-9]?
    +    rule4 ::= rule5 [0-9]? | rule6 [0-9]?
    +    rule5 ::= rule6 [0-9]? | rule7 [0-9]? | rule8 [0-9]?
    +    rule6 ::= rule7 [0-9]? | rule1 [0-9]?
    +    rule7 ::= rule8 [0-9]? | rule9 [0-9]?
    +    rule8 ::= rule9 [0-9]? | rule7 [0-9]?
    +    rule9 ::= [0-9]?
    +    """
    +
    +    grammar = xgr.Grammar.from_ebnf(mixed_grammar_str)
    +    input_str = ""
    +    for i in range(10):
    +        assert _is_grammar_accept_string(grammar, input_str)
    +        input_str += "0"
    +    assert _is_grammar_accept_string(grammar, input_str)
    +
    +    # Test right recursion
    +    right_recursion_grammar = "root ::= [a-z] root | [a-z]"
    +
    +    accept_strings = ["a", "ab", "abc", "abcd", "abcde"]
    +    reject_strings = ["", "1", "a1", "ab1", "abc1"]
    +    for accept_string in accept_strings:
    +        assert _is_grammar_accept_string(right_recursion_grammar, accept_string)
    +    for reject_string in reject_strings:
    +        assert not _is_grammar_accept_string(right_recursion_grammar, reject_string)
    +
    +    # Test the mixture of right recursion and other rules
    +    mixed_grammar_str = """root ::= rule1
    +    rule1 ::= "{" rule2 | ""
    +    rule2 ::= root "}"
    +    """
    +    test_strings = {"", "{}", "{{}}", "{{{}}}", "{{{{}}}}", "{{{{{}}}}}"}
    +    rejected_strings = {"{", "{}{}", "{{{{}", "{{}}}", "{{{{{}}}}}}"}
    +
    +    for test_string in test_strings:
    +        assert _is_grammar_accept_string(mixed_grammar_str, test_string)
    +    for rejected_string in rejected_strings:
    +        assert not _is_grammar_accept_string(mixed_grammar_str, rejected_string)
    +
    +
    +def test_advance():
    +    # Test complex Advance and completion with EBNF grammar.
    +    ebnf_grammar_str = """root ::= rule1
    +    rule1 ::= [a] | [a-b] | [a-c]* | "a" | "aaaaaaaaaaaaaaaaaaa"
    +    """
    +    grammar = xgr.Grammar.from_ebnf(ebnf_grammar_str)
    +    for i in range(10):
    +        input_str = "a" * i
    +        assert _is_grammar_accept_string(grammar, input_str)
    +
    +
    +def test_character_class_star_utf8():
    +    ebnf_grammar_str = """root ::= [^0-9]*"""
    +    test_string = "worldせかい世界"
    +    assert _is_grammar_accept_string(ebnf_grammar_str, test_string)
    +
    +
    +@pytest.mark.hf_token_required
    +def test_not_neighbour_character_class():
    +    raw_grammar = "root ::= [a-cx-z]*"
    +    tokenizer_path = "meta-llama/Llama-2-7b-chat-hf"
    +    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True)
    +    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
    +    grammar = xgr.Grammar.from_ebnf(raw_grammar)
    +    matcher = _get_matcher_from_grammar_and_tokenizer_info(grammar, tokenizer_info)
    +    token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
    +    matcher.fill_next_token_bitmask(token_bitmask)
    +    rejected_token_ids = _get_masked_tokens_from_bitmask(token_bitmask, tokenizer_info.vocab_size)
    +    assert len(rejected_token_ids) == 31933
    +
    +
     if __name__ == "__main__":
         pytest.main(sys.argv)
    
  • tests/python/test_grammar_matcher_json_schema.py+22 1 modified
    @@ -5,7 +5,7 @@
     
     import pytest
     from pydantic import BaseModel, Field
    -from transformers import AutoTokenizer
    +from transformers import AutoConfig, AutoTokenizer
     
     import xgrammar as xgr
     from xgrammar.testing import (
    @@ -490,5 +490,26 @@ class MainModel(BaseModel):
         assert matcher.is_terminated()
     
     
    +@pytest.mark.hf_token_required
    +def test_implicit_left_recursion_schema():
    +    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    +    tokenizer = AutoTokenizer.from_pretrained(model_name)
    +    config = AutoConfig.from_pretrained(model_name)
    +
    +    json_schema = {
    +        "$schema": "http://json-schema.org/draft-04/schema#",
    +        "type": "object",
    +        "properties": {
    +            "url": {
    +                "type": "string",
    +                "pattern": "^(https?://)?([\\da-z\\.-]+)\\.([a-z\\.]{2,6})([/\\w \\.-]*)*/?",
    +            }
    +        },
    +    }
    +    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size)
    +    grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
    +    _ = grammar_compiler.compile_json_schema(schema=json.dumps(json_schema))
    +
    +
     if __name__ == "__main__":
         pytest.main(sys.argv)
    
  • tests/python/test_recursion_depth.py+2 2 modified
    @@ -36,6 +36,7 @@ def test_error_set_recursion_depth():
     
     
     def test_recursion_exceed():
    +    # In Earley Parser, the recursion depth can't be exceeded.
         with xgr.max_recursion_depth(1000):
             grammar_ebnf = r"""
         root ::= "\"" basic_string "\""
    @@ -46,8 +47,7 @@ def test_recursion_exceed():
     
             matcher = _get_matcher_from_grammar(grammar_ebnf)
     
    -        with pytest.raises(RuntimeError):
    -            matcher.accept_string(input_str)
    +        matcher.accept_string(input_str)
     
     
     if __name__ == "__main__":
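
The `test_recursion_depth.py` change stops expecting a `RuntimeError`, with the comment that the recursion depth cannot be exceeded in the Earley parser. The sketch below is a textbook Earley recognizer, not XGrammar's implementation. It shows why: items are kept in a per-position chart and processed from a worklist, so left-recursive and nullable rules never recurse on the call stack. The grammar encoding (a dict of tuples, with single-character terminals) and all function names are assumptions made for illustration.

```python
from typing import Dict, List, Set, Tuple

Grammar = Dict[str, List[Tuple[str, ...]]]


def nullable_symbols(grammar: Grammar) -> Set[str]:
    """Fixpoint: a rule is nullable if some alternative is all-nullable."""
    nullable: Set[str] = set()
    changed = True
    while changed:
        changed = False
        for lhs, alternatives in grammar.items():
            if lhs not in nullable and any(
                all(sym in nullable for sym in rhs) for rhs in alternatives
            ):
                nullable.add(lhs)
                changed = True
    return nullable


def earley_accepts(grammar: Grammar, start: str, text: str) -> bool:
    """Items are (lhs, rhs, dot, origin); dict keys are nonterminals."""
    nullable = nullable_symbols(grammar)
    chart: List[set] = [set() for _ in range(len(text) + 1)]
    chart[0] = {(start, rhs, 0, 0) for rhs in grammar[start]}
    for i in range(len(text) + 1):
        worklist = list(chart[i])
        while worklist:
            lhs, rhs, dot, origin = worklist.pop()
            new = []
            if dot == len(rhs):
                # Complete: advance every item that was waiting for `lhs`.
                new = [(l, r, d + 1, o) for (l, r, d, o) in chart[origin]
                       if d < len(r) and r[d] == lhs]
            elif rhs[dot] in grammar:
                # Predict: add the alternatives of the referenced rule.
                new = [(rhs[dot], alt, 0, i) for alt in grammar[rhs[dot]]]
                if rhs[dot] in nullable:
                    # A nullable rule may match nothing, so also step over it.
                    new.append((lhs, rhs, dot + 1, origin))
            elif i < len(text) and text[i] == rhs[dot]:
                # Scan: consume one terminal character.
                chart[i + 1].add((lhs, rhs, dot + 1, origin))
            for item in new:
                if item not in chart[i]:
                    chart[i].add(item)
                    worklist.append(item)
    return any(lhs == start and dot == len(rhs) and origin == 0
               for lhs, rhs, dot, origin in chart[len(text)])
```

For example, the left-recursive grammar `{"expr": [("expr", "+", "n"), ("n",)]}` is handled without any special casing, because a prediction that is already in the chart is simply not added again.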
    
  • tests/python/test_serializer.py+62 0 modified
    @@ -1,3 +1,4 @@
    +import sys
     from typing import Any, List, Literal, Optional, Tuple
     
     import pytest
    @@ -209,3 +210,64 @@ def test_serializer_correctness_functional(
         rejected_sizes.append(len(rejected_token_ids))
         if expected_rejected_sizes is not None:
             assert rejected_sizes[-1] == expected_rejected_sizes[-1]
    +
    +
    +def test_serializer_correctness_in_mask_cache():
    +
    +    # test masks in grammar serialization
    +    tokenizer = xgr.TokenizerInfo(["1", "212", "a", "A", "b", "一", "-", "aBc", "abc"])
    +    test_grammar = """
    +        root ::= rule1 | rule2
    +        rule1 ::= [^0-9] rule1
    +        rule2 ::= ("AB" | "1" | "") rule2
    +    """
    +    expected_mask = "[[1,7,0,-1,0],[1,[],[1,2],[1,0,0,0],[]],[1,7,0,-1,1],[0,[],\
    +[],[1,0,0,0],[]],[1,7,0,-1,2],[0,[],[],[1,0,0,0],[]],[1,7,0,-1,3],[0,[],[],[1,0,\
    +0,0],[]],[3,15,0,-1,0],[0,[3],[],[1,0,0,0],[]],[3,15,0,-1,1],[0,[],[],[1,0,0,0],\
    +[]],[3,17,0,-1,0],[0,[1],[],[1,0,0,0],[]]]"
    +
    +    grammar_compiler = xgr.GrammarCompiler(tokenizer)
    +    compiled_grammar = grammar_compiler.compile_grammar(test_grammar)
    +    serial_json = compiled_grammar.serialize_json()
    +    mask_idx = serial_json.find('"adaptive_token_mask_cache":')
    +    assert mask_idx != -1
    +    mask_idx += len('"adaptive_token_mask_cache":')
    +    assert serial_json[mask_idx : mask_idx + len(expected_mask)] == expected_mask, (
    +        serial_json[mask_idx : mask_idx + len(expected_mask)],
    +        expected_mask,
    +    )
    +
    +    # test bitset
    +    test_list = []
    +    for i in range(201):
    +        test_list.append("1")
    +    for i in range(201):
    +        test_list.append("2")
    +    tokenizer = xgr.TokenizerInfo(test_list)
    +
    +    test_grammar = 'root ::= "1"'
    +    expected_mask = "[[0,1,0,-1,0],[2,[],[],[1,201,402,13,0,1,2,3,4,5,6,7,8,9,10\
    +,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,3\
    +7,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,\
    +64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90\
    +,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,\
    +113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,\
    +133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,\
    +153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,\
    +173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,\
    +193,194,195,196,197,198,199,200],[]]]"
    +
    +    grammar_compiler = xgr.GrammarCompiler(tokenizer)
    +    compiled_grammar = grammar_compiler.compile_grammar(test_grammar)
    +    serial_json = compiled_grammar.serialize_json()
    +    mask_idx = serial_json.find('"adaptive_token_mask_cache":')
    +    assert mask_idx != -1
    +    mask_idx += len('"adaptive_token_mask_cache":')
    +    assert serial_json[mask_idx : mask_idx + len(expected_mask)] == expected_mask, (
    +        serial_json[mask_idx : mask_idx + len(expected_mask)],
    +        expected_mask,
    +    )
    +
    +
    +if __name__ == "__main__":
    +    pytest.main(sys.argv)
    

Vulnerability mechanics

Generated on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
