A Reasoning-First LLM Framework: How to Train an LLM to Reason?
P.S. This is my first draft. The article is likely to go through more iterations and proofreading.
One of the topics I have become very interested in is the reasoning capabilities of Large Language Models (LLMs). In this article, I’ll explore several key techniques speculated within the machine learning community to have been used in training models like OpenAI’s o1 and other state-of-the-art LLMs on the market today. Since o1 is closed-source, the exact details remain unknown, but this piece offers a speculative overview, with references to relevant research papers.
My goal here is to introduce three core techniques that are instrumental in helping LLMs ‘develop’ reasoning skills.
Can the Artificial Intelligence community establish a framework that enables LLMs to think and reason just like humans do?
What is reasoning? System 1 vs System 2 Thinking
To better understand reasoning, we can categorise it into two systems of thinking known as ‘System 1’ and ‘System 2’, as introduced by Daniel Kahneman in his book ‘Thinking, Fast and Slow’ (https://thedecisionlab.com/reference-guide/philosophy/system-1-and-system-2-thinking). System 1 thinking is fast, automatic, and intuitive, operating with little to no effort. This mode of thinking allows us to make quick decisions and judgments based on patterns and experiences. In contrast, System 2 is slow, deliberate, and conscious, requiring intentional effort. This type of thinking is useful when there is a need to do complex problem-solving and analysis.
The relationship between System 1 and System 2 thinking is intricate and interdependent. System 1 thinking is crucial for the functioning of System 2 thinking. The rapid assessments made by System 1 lay the groundwork for the more deliberate analyses performed by System 2.
Deep learning models, particularly large language models (LLMs), have a substantial capacity for memorization. This allows them to learn vast amounts of knowledge across diverse topics, given that their training data includes a gigantic corpus drawn from the internet. As François Chollet, an AI researcher at Google, puts it, this gives them an extensive “database” of information and program structures that they can pick and choose from and reapply to solve new problems.
An Important Hypothesis about Hallucinations
Large language models are susceptible to hallucinations, a tendency often attributed to their autoregressive nature. LLMs generate text one token at a time, each time picking a likely next token (often the most probable one) based on everything generated so far. If the model inadvertently chooses an illogical or nonsensical token, the next token it generates is then conditioned on that erroneous choice. This can lead to a compounding effect, where each new word builds on the prior mistake, causing the error to worsen progressively. Without the ability to backtrack, the model’s output can rapidly diverge from coherence.
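To make the compounding effect concrete, here is a toy sketch of greedy autoregressive decoding. The "model" below is a hand-written lookup table standing in for an LLM's next-token probabilities, purely for illustration; the point is that once a wrong token is in the context, every subsequent token is conditioned on it and there is no mechanism to go back and revise it.

```python
# Toy illustration (not a real LLM) of greedy autoregressive decoding.
# next_token_distribution is a hypothetical stand-in for a model's
# next-token probabilities.

def next_token_distribution(context):
    last = context[-1]
    table = {
        "is":    {"Paris": 0.7, "Lyon": 0.3},
        "Paris": {".": 0.95, ",": 0.05},
        "Lyon":  {",": 0.8, ".": 0.2},   # the model must now continue a wrong claim
        ",":     {"obviously": 1.0},
        ".":     {"<eos>": 1.0},
    }
    return table.get(last, {"is": 1.0})

def continue_greedily(context, steps=2):
    context = list(context)
    for _ in range(steps):
        dist = next_token_distribution(context)
        context.append(max(dist, key=dist.get))  # pick the most probable next token
    return " ".join(context)

print(continue_greedily(["The", "capital", "of", "France", "is", "Paris"]))
print(continue_greedily(["The", "capital", "of", "France", "is", "Lyon"]))
# Once "Lyon" is in the context, every following token builds on that mistake.
```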
Yann LeCun describes reasoning in LLMs as still in its early stages, noting that the amount of computation an LLM spends per generated token is constant, regardless of how difficult the question is.
To address the limitations of System 1 thinking and reduce susceptibility to hallucinations in LLMs, models usually undergo additional stages following the pre-training phase. These stages, known as Supervised Fine-Tuning and Reinforcement Learning from Human Feedback, provide opportunities to teach the model reasoning skills.
Technique 1: Supervised Fine-Tuning
(https://ss8319.github.io/_posts/STAR.png) Figure shows the STaR (Self-Taught Reasoner) methodology from the paper “STaR: Bootstrapping Reasoning With Reasoning”.
Eric Zelikman et al. introduced the STaR method, where large language models (LLMs) are prompted to generate both a rationale and an answer. If the model’s answer is incorrect, it is provided with a hint to guide it toward the correct solution. The resulting (question, rationale, correct answer) triplets are collected to fine-tune the LLM, improving its reasoning capabilities. Since then, there have been many variations of this technique, all built around the core idea of fine-tuning an LLM on rationales and reasoning traces.
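As a rough illustration, here is a condensed sketch of one STaR iteration. The functions `generate_rationale_and_answer` and `fine_tune` are hypothetical placeholders for calls into an actual LLM and training pipeline, so treat this as a sketch of the loop rather than the paper’s implementation.

```python
# A condensed sketch of one STaR outer-loop iteration (after Zelikman et al.).
# generate_rationale_and_answer and fine_tune are hypothetical placeholders.

def star_iteration(model, dataset, generate_rationale_and_answer, fine_tune):
    finetune_examples = []
    for question, gold_answer in dataset:
        # 1) Ask the model for a rationale followed by an answer.
        rationale, answer = generate_rationale_and_answer(model, question, hint=None)
        if answer != gold_answer:
            # 2) Rationalization: retry with the correct answer given as a hint,
            #    so the model produces a rationale that leads to it.
            rationale, answer = generate_rationale_and_answer(
                model, question, hint=gold_answer
            )
        if answer == gold_answer:
            # 3) Keep only (question, rationale, answer) triples that end correctly.
            finetune_examples.append((question, rationale, gold_answer))
    # 4) Fine-tune on the collected rationales and repeat from the new model.
    return fine_tune(model, finetune_examples)
```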
Technique 2: Reinforcement Learning from Human Feedback (RLHF)
(https://ss8319.github.io/_posts/o1_excerpt.png) The excerpt above comes from OpenAI’s own blog about their o1 model. We know that OpenAI is leveraging large-scale Reinforcement Learning to teach o1 to reason.
In the stage known as LLM alignment, Reinforcement Learning from Human Feedback (RLHF) is employed. A reward model provides a scalar feedback value that scores the model’s output; this reward is crucial for optimizing the model’s behavior toward the desired outcome. The reward model is typically trained on human preferences, where human annotators rank or score different outputs generated by the model for a given prompt.
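A minimal sketch of how such a reward model is commonly trained is shown below, using the pairwise-preference loss: the reward for the human-preferred response should exceed the reward for the rejected one. `reward_model` is assumed to map a (prompt, response) pair to a scalar; its architecture is not shown here, and the exact loss used by any particular lab may differ.

```python
import torch.nn.functional as F

# Pairwise-preference loss commonly used for RLHF reward models.
# reward_model(prompt, response) is assumed to return a scalar tensor.

def preference_loss(reward_model, prompt, chosen, rejected):
    r_chosen = reward_model(prompt, chosen)      # reward for the preferred output
    r_rejected = reward_model(prompt, rejected)  # reward for the rejected output
    # Push the preferred reward above the rejected one: -log sigmoid(r_chosen - r_rejected)
    return -F.logsigmoid(r_chosen - r_rejected).mean()
```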
(https://ss8319.github.io/_posts/RLHF.png)
Proximal Policy Optimization (PPO), a popular reinforcement learning algorithm, is often used in this stage. PPO helps optimize the model’s policy by iteratively improving it while ensuring that changes to the model are not too drastic, preventing the model from deviating too far from what is known to work. The central idea in PPO is to update the model’s parameters in a way that maximizes the reward while controlling how much the new policy diverges from the old one. It achieves this with a clipped objective function, which keeps updates within a trust region and stabilizes learning (a sketch of this clipped objective follows the steps below).
The steps generally include:
- The model generates outputs based on the current policy.
- The reward model provides feedback on these outputs.
- PPO adjusts the policy by maximizing the reward, making sure the model improves gradually without large, unstable changes.
- The process repeats, with the model refining its behavior based on human-guided feedback until the desired level of performance is achieved.
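Here is a sketch of the clipped surrogate objective at the heart of PPO as it is typically applied in RLHF-style training. The variable names are illustrative: `logprobs_new` and `logprobs_old` are per-token log-probabilities of the sampled response under the current and pre-update policies, and `advantages` are derived from the reward model’s feedback (often with a KL penalty folded in).

```python
import torch

# PPO's clipped surrogate objective (illustrative names, minimal sketch).

def ppo_clipped_loss(logprobs_new, logprobs_old, advantages, clip_eps=0.2):
    ratio = torch.exp(logprobs_new - logprobs_old)  # how much the policy has moved
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    # Take the pessimistic (minimum) objective so large policy jumps are not rewarded.
    return -torch.min(unclipped, clipped).mean()
```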
The feedback from RLHF can take two forms: outcome supervision, which evaluates the model’s final result, or process supervision, which evaluates intermediate reasoning steps. Hunter Lightman et al. demonstrated that process supervision, where feedback is given on each intermediate reasoning step, significantly outperforms outcome supervision, which only evaluates the final result, on complex tasks like those in the MATH dataset. By training a Process Supervision Reward Model (PRM) to assess each step, the model can learn from errors more precisely, improving problem-solving accuracy. This finer-resolution feedback, though expensive because humans must annotate every step, is more effective at catching “convincing wrong-answer” solutions.
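The contrast between the two forms of supervision can be sketched at scoring time as follows. `orm` and `prm` are hypothetical reward models: the first scores only the final solution, while the second returns, for each intermediate step, the probability that the step is correct; aggregating per-step scores by taking their product is one choice consistent with Lightman et al., but other aggregations are possible.

```python
import math

# Outcome supervision: a single score for the final result only.
def outcome_score(orm, question, solution):
    return orm(question, solution)

# Process supervision: score every reasoning step, then aggregate.
def process_score(prm, question, steps):
    step_probs = [prm(question, steps[: i + 1]) for i in range(len(steps))]
    return math.prod(step_probs)
```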
Technique 3: Scaling Inference
Well, technically we aren’t teaching the LLM to reason anymore. Instead, we are leveraging more compute at inference time. You can think of this as trading inference time for accuracy. There are already many papers on the best ways to prompt LLMs in different scenarios; in my experience, most are some variant of Chain of Thought (CoT) prompting.
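For concreteness, here is the kind of zero-shot CoT prompt this family of methods builds on; the question is made up for illustration, and the only change from a plain prompt is the instruction to reason step by step.

```python
# A minimal zero-shot Chain-of-Thought prompt (illustrative question).
prompt = (
    "Q: A train travels 60 km in 45 minutes. What is its average speed in km/h?\n"
    "A: Let's think step by step."
)
```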
However, let’s review a highly popular algorithm called “Tree of Thoughts” (ToT), from a paper by DeepMind and Princeton researchers that generalizes CoT.
(https://ss8319.github.io/_posts/ToT.png)
ToT is a sophisticated framework designed to enhance the problem-solving capabilities of LLMs by structuring their reasoning in a manner analogous to human cognitive processes. The framework is composed of four key components:
- Thought decomposition: The ToT framework explicitly breaks a problem into smaller, manageable steps called thoughts, which are pieced together to form a solution. Each thought should be the right size: not too large to handle, and not too small to be useful. For example, if you’re planning a trip, a thought might involve deciding on a travel destination first, then choosing the best mode of transportation and finally picking a place to stay. In a mathematical problem, a thought might be a single equation line or a concise concept explanation. This way, the problem is broken down into key steps that are easy to tackle and evaluate individually. The decomposition depends on the nature of the problem, making sure that thoughts are both significant and feasible for evaluation.
- Thought generation: After defining what constitutes a thought, the next step is to determine how these thoughts are generated. The framework proposes two primary techniques.
- Sampling: This technique involves generating several thoughts independently by using the same prompt. It works best when the thought space is rich and diverse, as independently generated thoughts are less likely to be duplicated. For example, in creative writing, multiple independent plot ideas might be generated.
- Proposing: This technique sequentially generates thoughts using a “propose prompt.” Each thought is built upon the previous one, which helps avoid duplication in more constrained thought spaces. For example, in logical problem-solving, each step builds on the previous one to help ensure consistency and progress.
- State evaluation: Once thoughts are generated, they must be evaluated to help ensure progress toward a solution. The framework employs 2 strategies for this purpose:
- Value: This strategy involves assigning a scalar value (for example, a rating from 1-10) or a classification (for example, sure, likely or impossible) to each state. This helps indicate the state’s quality or likelihood of leading to a solution. This method allows for a quantitative assessment of each thought’s potential.
- Vote: This strategy compares different solutions and selects the most promising one. Voting is particularly useful for tasks where the quality of a solution is subjective or hard to quantify, such as in creative writing or strategic planning. Multiple evaluations combine to determine the best path forward.
- Search algorithm: The final component involves the search algorithm used to navigate through the solution space. The framework typically employs 2 fundamental algorithms:
- Breadth-first search (BFS): This algorithm explores all possible branches at each level before moving deeper into the tree. It makes sure that all potential solutions are considered equally, making it useful for problems where the shortest path or shallowest solution is preferred. For example, in a puzzle game, BFS would check all immediate moves before considering subsequent ones.
- Depth-first search (DFS): This algorithm explores one branch deeply before backtracking to explore other branches. It allows for a thorough examination of each potential solution path, making it useful for problems requiring detailed exploration of each option. For example, in solving a complex logic problem, DFS would follow a single hypothesis deeply, checking its validity before considering alternatives.
By integrating these components, the ToT framework mimics human problem-solving by systematically considering multiple solutions and discarding the ones that are found incorrect.
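To tie the four components together, here is a compact sketch of a ToT loop using breadth-first search with value-based evaluation. `propose_thoughts` (thought generation) and `value_state` (state evaluation) are hypothetical stand-ins for LLM calls, and the depth/breadth parameters are arbitrary; the paper’s own implementations vary per task.

```python
# A minimal ToT-style search: generate candidate thoughts, score the resulting
# states, and keep only the most promising ones at each level (BFS variant).

def tree_of_thoughts_bfs(problem, propose_thoughts, value_state,
                         depth=3, breadth=5, keep=2):
    frontier = [[]]  # each state is the list of thoughts produced so far
    for _ in range(depth):
        candidates = []
        for state in frontier:
            # Thought generation: extend each state with several candidate thoughts.
            for thought in propose_thoughts(problem, state, n=breadth):
                candidates.append(state + [thought])
        # State evaluation: score candidates and keep only the best ones.
        candidates.sort(key=lambda s: value_state(problem, s), reverse=True)
        frontier = candidates[:keep]
    return frontier[0]  # the highest-valued chain of thoughts found
```

A DFS variant would instead follow one promising branch to full depth and backtrack when its value drops below a threshold, trading breadth of coverage for depth of exploration.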