[Feature] Improved MultiChainComparison #8088
Labels: enhancement
chenmoneygithub commented: @xaviermehaut Thanks for reporting the feature request! Would you mind explaining more about your idea in the description, or opening a draft PR? Zip files can be dangerous, so we generally don't open them.
xaviermehaut replied: Thanks for your answer. Here are the two files (one for the module and one for the example script that was in the zip).
```python
# --- START OF FILE improved_multi_chain_comparison.py (Updated Docstring) ---
from dspy.signatures import InputField, OutputField
from dspy.signatures.signature import ensure_signature
from dspy.primitives.program import Module
from dspy.predict.predict import Predict


class ImprovedMultiChainComparison(Module):
    """
    An improved version of MultiChainComparison that structures inputs more clearly
    and uses a more directive prompt incorporating a structured reasoning step
    before final synthesis and correction. This explicit two-phase process,
    akin to patterns like ReAct (Reason+Act), often leads to more reliable
    refinement compared to less structured approaches.

    Goal: To generate a refined reasoning process and final answer by first
    explicitly analyzing multiple initial attempts (Reason phase) and then
    synthesizing a corrected output based on that analysis (Act phase).
    This approach aims to improve results even with less powerful LLMs.

    Takes M completions (each expected to contain a 'rationale' or 'reasoning' field,
    and the final answer field specified in the original signature) and asks a final
    LLM call to perform two key steps within a single prediction:
    1. **Analyze (Reason Phase):** Explicitly compare the provided attempts,
       identifying similarities, differences, and potential errors.
    2. **Synthesize/Correct (Act Phase):** Generate a corrected, step-by-step
       reasoning process and provide the final answer fields based *on the
       preceding analysis*.
    """

    def __init__(self, signature, M=3, temperature=0.7,
                 synthesis_instructions=None,  # Optional: override default synthesis prompt
                 **config):
        """
        Args:
            signature (dspy.Signature): The signature defining the overall task (inputs and final outputs).
            M (int, optional): The number of comparison attempts to generate/expect. Defaults to 3.
            temperature (float, optional): Temperature for the final synthesis LLM call. Defaults to 0.7.
                Consider using a lower temperature (e.g., 0.1-0.5) for more focused synthesis.
            synthesis_instructions (str, optional): A custom prompt prefix for the analysis and
                corrected reasoning step. If None, uses a default structured prompt. Defaults to None.
            **config: Additional configuration for the dspy.Predict module.
        """
        super().__init__()
        self.M = M

        # 1. Ensure we have a manipulable signature object
        original_signature = ensure_signature(signature)

        # Identify the last output field key from the original signature.
        # This key will be used to extract the 'answer' from each completion in forward().
        *_, self.last_key = original_signature.output_fields.keys()

        # Define the default structured instructions for the synthesis step
        default_instructions = (
            "You are given multiple attempts to solve a problem. Your task is to analyze these attempts and produce a final, corrected reasoning and answer.\n\n"
            "--- Analysis of Attempts ---\n"
            "1. Carefully compare the reasoning steps provided in the attempts below.\n"
            "2. Identify key similarities and differences in their approaches.\n"
            "3. Point out any potential errors, inconsistencies, or missing steps in the attempts.\n\n"
            "--- Corrected Reasoning ---\n"
            "Based on your analysis, provide the most accurate and complete step-by-step reasoning to solve the problem correctly.\n"
            "Reasoning:"  # The LLM will continue generating the reasoning here
        )

        # Use custom instructions if provided, otherwise use the default
        final_synthesis_instructions = synthesis_instructions or default_instructions

        # 2. Start building the new signature for the synthesis step,
        #    using the object returned by ensure_signature
        synthesis_signature = original_signature

        # 3. Prepend the new output field for analysis and corrected reasoning.
        #    This returns a NEW modified signature.
        synthesis_signature = synthesis_signature.prepend(
            "analysis_and_corrected_reasoning",
            OutputField(
                prefix=final_synthesis_instructions,
                desc="${analysis and corrected reasoning steps}",
            ),
        )

        # 4. Append input fields for each student attempt.
        #    Each call to append also returns a NEW modified signature.
        for idx in range(M):
            synthesis_signature = synthesis_signature.append(
                f"attempt_{idx + 1}",  # Use a simpler key name
                InputField(
                    prefix=f"--- Attempt #{idx + 1} ---",
                    desc="${reasoning attempt}",  # Description for clarity; value passed in forward()
                    format=str,  # Ensure the value is treated as a string
                ),
            )

        # Initialize the predictor with the fully constructed synthesis signature
        self.predict = Predict(synthesis_signature, temperature=temperature, **config)

    def forward(self, completions, **kwargs):
        """
        Processes M completions to generate a synthesized, corrected output.

        Args:
            completions (list[dspy.Prediction | dict]): A list containing M completions.
                Each completion must have a 'rationale' or 'reasoning' field, and
                the final answer field identified by self.last_key.
            **kwargs: Additional input fields required by the original signature.

        Returns:
            dspy.Prediction: The final prediction object containing the analysis,
                corrected reasoning, and the final answer fields from the original signature.
        """
        attempts_structured = []
        for idx, c in enumerate(completions):
            # Extract rationale/reasoning (handle both keys for flexibility)
            rationale = c.get('rationale', c.get('reasoning', '')).strip()
            # Extract the final answer using the identified last key
            answer = str(c.get(self.last_key, '')).strip()
            # Use a clearer, structured format for each attempt passed to the LLM
            attempt_text = (
                f"Reasoning: {rationale}\nPrediction: {answer}"  # Compact format
            )
            attempts_structured.append(attempt_text)

        # Ensure the number of provided completions matches the expected M
        if len(attempts_structured) != self.M:
            raise ValueError(
                f"The number of provided completions ({len(attempts_structured)}) "
                f"does not match the expected number M ({self.M}). "
                f"Please ensure the input list 'completions' has exactly M elements."
            )

        # Prepare keyword arguments for the final prediction call
        predict_kwargs = {
            **{  # Pass the structured attempts as input fields
                f"attempt_{idx + 1}": attempt
                for idx, attempt in enumerate(attempts_structured)
            },
            **kwargs,  # Pass through any other necessary inputs from the original signature
        }

        # Call the final predictor
        return self.predict(**predict_kwargs)
# --- END OF FILE improved_multi_chain_comparison.py ---
```
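For quick reference, here is a minimal usage sketch. The `BasicQA` signature, the question text, and the Ollama model/endpoint are illustrative assumptions only; adapt them to your own setup.

```python
import dspy
from improved_multi_chain_comparison import ImprovedMultiChainComparison  # the file above

# Illustrative signature: any signature with at least one output field works.
class BasicQA(dspy.Signature):
    """Answer the question."""
    question = dspy.InputField()
    answer = dspy.OutputField()

# Illustrative LM configuration (any dspy-supported LM can be used).
dspy.configure(lm=dspy.LM('ollama_chat/gemma3:27b', api_base='http://localhost:11434'))

# 1. Generate M independent chain-of-thought attempts.
cot = dspy.ChainOfThought(BasicQA, n=1)
question = "John has 5 apples. He buys 3 baskets each containing 4 apples. How many apples does he have in total?"
attempts = [cot(question=question).completions[0] for _ in range(3)]

# 2. Analyze the attempts and synthesize a corrected reasoning and answer.
compare = ImprovedMultiChainComparison(BasicQA, M=3, temperature=0.1)
result = compare(completions=attempts, question=question)

print(result.analysis_and_corrected_reasoning)
print(result.answer)
```

The full example script, which runs the same comparison across several different signatures and also calls the original MultiChainComparison, follows: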
```python
# --- START OF FILE example_diverse_tasks_comparison.py ---
import os

import dspy
from dotenv import load_dotenv

# Import the two comparison modules
from dspy import MultiChainComparison  # The original one

# Make sure the path to the improved module is correct
try:
    from dspy_addons.multi_chain_comparison.multi_chain_comparison_improved import ImprovedMultiChainComparison
except ImportError:
    # Fallback if the addon structure isn't used, assuming it's in the same directory
    try:
        from improved_multi_chain_comparison import ImprovedMultiChainComparison
    except ImportError:
        print("ERROR: Could not import ImprovedMultiChainComparison. Ensure the file is accessible.")
        exit()

"""
Module: example_diverse_tasks_comparison.py

Compares `MultiChainComparison` (original) and `ImprovedMultiChainComparison`
on a SERIES of examples covering DIVERSE TASKS (different signatures)
to evaluate their behavior when facing different types of reasoning/generation.

How it works:
1. Defines SEVERAL `Signature` classes for various tasks.
2. Defines an `examples` list where each element contains:
   - `signature_class`: The signature class to use.
   - `input_data`: A dictionary with the input data for that signature.
3. For EACH example:
   a. Re-initializes the modules (CoT, Original Comparison, Improved Comparison)
      WITH THE SPECIFIC SIGNATURE for the example.
   b. Manually generates multiple attempts via CoT.
   c. Calls `ImprovedMultiChainComparison`.
   d. Calls `MultiChainComparison` (original).
   e. Displays the comparative results.
"""

load_dotenv()

# --- LLM Configuration ---
# Choose and configure your LLM

# Example using Ollama Gemma 3
llm = dspy.LM(
    model='ollama_chat/gemma3:27b',  # Adjust model name if needed
    api_base=os.getenv("API_URL", "http://localhost:11434"),  # Default if not set
    temperature=0.3,  # Temperature for generating initial attempts (adjust as needed)
    max_tokens=8192,
    top_p=0.9,
    top_k=64,
    model_type='chat',
    cache=False  # Disable cache for testing diversity if needed
)

# Example using Ollama Llama 3
# llm = dspy.LM(
#     model='ollama_chat/llama3',
#     api_base=os.getenv("API_URL", "http://localhost:11434"),
#     temperature=0.3,
#     max_tokens=8192,
#     model_type='chat'
# )

# Example using Gemini
# google_api_key = os.getenv("GOOGLE_API_KEY")
# if not google_api_key: print("ERROR: GOOGLE_API_KEY missing"); exit()
# os.environ["GOOGLE_API_KEY"] = google_api_key
# llm = dspy.LM(model='gemini-1.5-flash', temperature=0.3, max_tokens=1500)

dspy.configure(lm=llm)


# --- Signature Definitions for Diverse Tasks ---
class SimpleExplanation(dspy.Signature):
    """Explains a concept simply for a specific audience."""
    concept_name = dspy.InputField(desc="The concept to explain.")
    target_audience = dspy.InputField(desc="The target audience for the explanation.")
    explanation = dspy.OutputField(desc="A clear, simple explanation tailored to the audience.")


class SolveMathProblem(dspy.Signature):
    """Solves a simple math problem step by step."""
    problem_statement = dspy.InputField(desc="The statement of the math problem.")
    reasoning_steps = dspy.OutputField(desc="The detailed reasoning steps to reach the solution.")
    final_answer = dspy.OutputField(desc="The numerical or final answer to the problem.")


class SummarizeText(dspy.Signature):
    """Summarizes a given text respecting an approximate length."""
    text_to_summarize = dspy.InputField(desc="The original text to summarize.")
    desired_length = dspy.InputField(desc="Desired summary length (e.g., 'one sentence', 'about 50 words').")
    summary = dspy.OutputField(desc="The concise summary of the text.")


class WriteShortPoem(dspy.Signature):
    """Writes a short poem on a given theme."""
    theme = dspy.InputField(desc="The theme or subject of the poem.")
    style_suggestion = dspy.InputField(desc="Suggestion for style or tone (e.g., 'joyful', 'melancholic', 'haiku').")
    poem = dspy.OutputField(desc="The generated poem.")
# Note: The concept of 'corrected reasoning' is less relevant here,
# but let's see how the modules handle it.


class EvaluateProsCons(dspy.Signature):
    """Evaluates the pros and cons of a given topic."""
    topic = dspy.InputField(desc="The topic to evaluate.")
    pros = dspy.OutputField(desc="List of the main advantages or positive points.")
    cons = dspy.OutputField(desc="List of the main disadvantages or negative points.")
    balanced_conclusion = dspy.OutputField(desc="A nuanced conclusion weighing the pros and cons.")


# --- List of Examples to Process (with signatures and data) ---
examples = [
    {
        "signature_class": SimpleExplanation,
        "input_data": {
            "concept_name": "Photosynthesis",
            "target_audience": "an 8-year-old child"
        }
    },
    {
        "signature_class": SolveMathProblem,
        "input_data": {
            "problem_statement": "John has 5 apples. He buys 3 baskets each containing 4 apples. How many apples does John have in total?"
        }
    },
    {
        "signature_class": SummarizeText,
        "input_data": {
            "text_to_summarize": "Solar energy is a renewable energy source that comes from the sun. It can be captured by photovoltaic panels to produce electricity or by thermal collectors to heat water. It is a clean alternative to fossil fuels, helping to reduce greenhouse gas emissions and energy dependence. Its cost has dropped considerably in recent years, making it increasingly competitive.",
            "desired_length": "two sentences maximum"
        }
    },
    {
        "signature_class": WriteShortPoem,
        "input_data": {
            "theme": "Autumn rain on a window",
            "style_suggestion": "calm and a bit melancholic"
        }
    },
    {
        "signature_class": EvaluateProsCons,
        "input_data": {
            "topic": "Widespread remote work for office employees"
        }
    }
]

# --- Global Parameters ---
num_attempts = 3  # Number of attempts to generate per example


# --- Function to process an example (with module re-initialization) ---
def run_comparison_for_example(example_config, n_attempts):
    """Generates attempts and runs both comparison modules for a specific example."""
    signature_class = example_config['signature_class']
    input_data = example_config['input_data']

    # Dynamically get input and output keys from the signature
    input_keys = list(signature_class.input_fields.keys())
    output_keys = list(signature_class.output_fields.keys())

    print(f"\n\n{'=' * 20} NEXT EXAMPLE ({signature_class.__name__}) {'=' * 20}")
    for key, value in input_data.items():
        print(f"Input [{key}]: {value}")
    print(f"{'=' * (40 + len(f' NEXT EXAMPLE ({signature_class.__name__}) '))}\n")

    # --- Re-initialize modules FOR THIS SIGNATURE ---
    # This is crucial because DSPy modules are tied to the signature they are initialized with
    print(f"--- Initializing modules for signature: {signature_class.__name__} ---")
    cot_generator = dspy.ChainOfThought(signature_class, n=1)
    comparison_module_improved = ImprovedMultiChainComparison(
        signature=signature_class, M=n_attempts, temperature=0.1  # Low temp for comparison
    )
    comparison_module_original = MultiChainComparison(
        signature=signature_class, M=n_attempts, temperature=0.1  # Low temp for comparison
    )
    # --- End Initialization ---

    # 1. Generate attempts
    completions_list = []
    print(f"--- Generating {n_attempts} attempts ---")
    for i in range(n_attempts):
        print(f"  Generating attempt {i+1}/{n_attempts}...")
        try:
            # Pass input_data dynamically using **kwargs
            prediction = cot_generator(**input_data)
            if prediction.completions:
                completions_list.append(prediction.completions[0])
            else:
                print(f"  WARNING: Attempt {i+1} returned no completion.")
        except Exception as e:
            print(f"  ERROR during attempt {i+1} generation: {e}")

    actual_attempts = len(completions_list)
    # Ensure exactly M attempts were generated before proceeding
    if actual_attempts != n_attempts:
        print(f"\nWARNING: Generated {actual_attempts} completions, but expected M={n_attempts}. Skipping comparison for this example.")
        return  # Skip if the number of attempts is not exactly M

    # Display generated attempts (optional, showing only the first output field for brevity)
    print("\n--- Generated Attempts (Displaying first output field) ---")
    first_output_key = output_keys[0]
    for i, comp in enumerate(completions_list):
        print(f"--- Attempt {i+1} ({first_output_key}) ---")
        print(f"  {getattr(comp, first_output_key, 'N/A')}")
        print("-" * 20)

    # 2. Execute ImprovedMultiChainComparison
    print(f"\n--- Calling ImprovedMultiChainComparison (expects M={comparison_module_improved.M}) ---")
    try:
        final_pred_improved = comparison_module_improved(
            completions=completions_list, **input_data
        )
        print("\n--- Result [ImprovedMultiChainComparison] ---")
        print(f"Analysis & Corrected Reasoning:\n{getattr(final_pred_improved, 'analysis_and_corrected_reasoning', 'N/A')}")
        print("-" * 30)
        # Display all final output fields defined in the signature
        for key in output_keys:
            print(f"Final Output [{key}]:\n{getattr(final_pred_improved, key, 'N/A')}")
            print("-" * 30)
    except Exception as e:
        print(f"ERROR [ImprovedMultiChainComparison]: {e}")

    # 3. Execute MultiChainComparison (Original)
    print(f"\n--- Calling MultiChainComparison Original (expects M={comparison_module_original.M}) ---")
    try:
        final_pred_original = comparison_module_original(
            completions=completions_list, **input_data
        )
        print("\n--- Result [MultiChainComparison Original] ---")
        print(f"Original Rationale:\n{getattr(final_pred_original, 'rationale', 'N/A')}")
        print("-" * 30)
        # Display all final output fields defined in the signature
        for key in output_keys:
            print(f"Final Output [{key}]:\n{getattr(final_pred_original, key, 'N/A')}")
            print("-" * 30)
    except Exception as e:
        print(f"ERROR [MultiChainComparison Original]: {e}")


# --- Main Loop ---
if __name__ == "__main__":
    if not examples:
        print("No examples defined in the 'examples' list.")
    else:
        print(f"Starting comparison on {len(examples)} diverse examples...")
        for example_config in examples:
            run_comparison_for_example(
                example_config=example_config,
                n_attempts=num_attempts
            )
        print("\n\n--- Finished all examples ---")
# --- END OF FILE example_diverse_tasks_comparison.py ---
```
What feature would you like to see?
MultiChainComparison could be a very useful component, but its behaviour seems too unpredictable and inaccurate, especially with SLMs or small LLMs like the ones available on Ollama.
I propose an improvement that rewrites the internal signature in a ReAct spirit. On multiple examples, the answers seem more precise and predictable.
My two cents.
PS: I added the code of the improvement along with an example module, FYI.
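To illustrate the core of the idea: the module's internal signature is rebuilt so the model first reasons over the M attempts and then acts by emitting the corrected fields. Below is a minimal sketch distilled from the attached code (field names follow the attached module; this is an illustration, not the full implementation):

```python
from dspy.signatures import InputField, OutputField
from dspy.signatures.signature import ensure_signature

def build_synthesis_signature(signature, M=3):
    """Sketch: extend an existing signature with a Reason-then-Act structure."""
    sig = ensure_signature(signature)
    # Reason phase: an explicit analysis / corrected-reasoning output, generated first.
    sig = sig.prepend(
        "analysis_and_corrected_reasoning",
        OutputField(desc="${analysis and corrected reasoning steps}"),
    )
    # Act phase inputs: the M prior attempts the model must compare and correct.
    for idx in range(M):
        sig = sig.append(
            f"attempt_{idx + 1}",
            InputField(prefix=f"--- Attempt #{idx + 1} ---", desc="${reasoning attempt}", format=str),
        )
    return sig
```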
Would you like to contribute?
Additional Context
multi_chain_comparison.zip