import os
import warnings
from typing import Type

import dspy
from crewai.tools import BaseTool
from dotenv import load_dotenv
from dspy.teleprompt import LabeledFewShot
from pydantic import BaseModel, Field

# === Load secrets ===
# The .env location defaults to the original deployment path, but it can be
# overridden with TOPIC_VALIDATOR_ENV_FILE so the module also works outside
# that specific VM. load_dotenv does not override variables already present
# in the process environment.
_DEFAULT_ENV_FILE = "/home/azureuser/microlearn/backend/prompt_opt_dspy/.env"
load_dotenv(dotenv_path=os.getenv("TOPIC_VALIDATOR_ENV_FILE", _DEFAULT_ENV_FILE))

azure_api_key = os.getenv("AZURE_API_KEY")
azure_api_host = os.getenv("AZURE_API_HOST")
azure_api_version = os.getenv("AZURE_API_VERSION")

# Warn early about missing credentials; otherwise the problem only shows up
# later, on the first LM request. A warning (not an exception) keeps importing
# this module from failing where it previously succeeded.
_missing_azure_settings = [
    env_name
    for env_name, env_value in (
        ("AZURE_API_KEY", azure_api_key),
        ("AZURE_API_HOST", azure_api_host),
        ("AZURE_API_VERSION", azure_api_version),
    )
    if not env_value
]
if _missing_azure_settings:
    warnings.warn(
        "Missing Azure OpenAI settings: " + ", ".join(_missing_azure_settings),
        RuntimeWarning,
    )

# === Configure Azure OpenAI LM ===
# temperature=0.0 keeps the validator's labels as repeatable as the backend allows.
lm = dspy.LM(
    "azure/Csqr-gpt-4o-mini",
    api_key=azure_api_key,
    api_base=azure_api_host,
    api_version=azure_api_version,
    temperature=0.0,
)
# Registers `lm` as DSPy's global default LM for every predictor in this process.
dspy.configure(lm=lm)


# === DSPy Signature ===
# NOTE: DSPy uses this class docstring as the LM instructions, so it is prompt
# text rather than developer documentation. The topic value is supplied via the
# `topic` InputField; a "{topic}" placeholder inside the docstring is never
# substituted and would reach the model literally, so none is used here.
class TopicValidationSignature(dspy.Signature):
    """
    You are a smart validator that checks if a given input represents a meaningful, real-world topic suitable for research, learning, or exploration in fields like science, technology, business, or education.

    A valid topic should:
    - Be widely recognized or commonly discussed
    - Contain meaningful words (not just random characters or test strings)
    - Be at least 2–3 characters long

    An invalid topic is:
    - Nonsensical
    - Too short or made-up without real context

    Output format:
        Based on this, evaluate the input and respond with either:
        "Valid Topic" or "Invalid Topic" only.
        Do not include any other comments.
    """
    # Raw topic text to classify.
    topic: str = dspy.InputField()
    # Predicted label; the instructions constrain it to the two strings in desc.
    valid_topic: str = dspy.OutputField(desc="Either 'Valid Topic' or 'Invalid Topic'")


# === DSPy Module ===
class TopicValidationModule(dspy.Module):
    """DSPy module that labels a topic as "Valid Topic" or "Invalid Topic".

    Construction compiles the underlying predictor with ``dspy.BootstrapFewShot``
    over a small labelled training set. Compilation issues LM calls against the
    globally configured DSPy LM, so build one instance and reuse it.
    """

    def __init__(self):
        super().__init__()
        # Un-optimized predictor; kept as an attribute because it is the compile student.
        self.predictor = dspy.Predict(TopicValidationSignature)

        # --- Training Data ---
        trainset = [
            dspy.Example(topic="Artificial Intelligence", valid_topic="Valid Topic").with_inputs("topic"),
            dspy.Example(topic="qwertyuiop", valid_topic="Invalid Topic").with_inputs("topic"),
            dspy.Example(topic="Climate Change", valid_topic="Valid Topic").with_inputs("topic"),
            dspy.Example(topic="asdf", valid_topic="Invalid Topic").with_inputs("topic"),
        ]

        allowed_labels = ("Valid Topic", "Invalid Topic")

        def validation_metric(gold, pred, trace=None):
            """Score a bootstrapped prediction: 1 if it is a correct, well-formed label, else 0.

            Reads ``valid_topic``, the output field declared on
            TopicValidationSignature. The previous version read a non-existent
            ``validation`` attribute, which made every score 0. Requiring a match
            with the gold label stops wrongly-labelled demos from being kept.
            """
            predicted = (getattr(pred, "valid_topic", None) or "").strip()
            expected = (getattr(gold, "valid_topic", None) or "").strip()
            return int(predicted in allowed_labels and predicted == expected)

        # --- Optimizer ---
        optimizer = dspy.BootstrapFewShot(metric=validation_metric)
        self.optimized_predictor = optimizer.compile(
            student=self.predictor,
            trainset=trainset,
        )

    def forward(self, topic: str):
        """Run the compiled predictor on *topic*; the result exposes ``valid_topic``."""
        return self.optimized_predictor(topic=topic)


# === Tool Input Schema ===
class TopicValidationInput(BaseModel):
    """Input schema for TopicValidationTool."""
    # NOTE(review): the docstring above and the Field description below are
    # runtime text; they presumably end up in the tool's generated argument
    # schema shown to the agent — confirm with crewai before rewording.
    # `...` makes `topic` a required field: the raw text to classify.
    topic: str = Field(..., description="The topic string to validate")


# === Custom DSPy Tool ===

# Lazily-built validator shared by all TopicValidationTool calls. Constructing a
# TopicValidationModule runs BootstrapFewShot compilation (LM calls over the
# training set), so it is done once per process instead of on every _run().
_topic_validator = None


def _get_topic_validator():
    """Return the shared TopicValidationModule, compiling it on first use."""
    global _topic_validator
    if _topic_validator is None:
        _topic_validator = TopicValidationModule()
    return _topic_validator


class TopicValidationTool(BaseTool):
    # CrewAI tool wrapper: the agent passes `topic` (see TopicValidationInput)
    # and receives the validator's label as a string.
    name: str = "Topic Validation Tool"
    description: str = (
        "Validates if a given topic is meaningful. Returns 'Valid Topic' or 'Invalid Topic'."
    )
    args_schema: Type[BaseModel] = TopicValidationInput

    def _run(self, topic: str) -> str:
        """Classify *topic* with the shared DSPy validator.

        Args:
            topic: Free-text topic to validate.

        Returns:
            The predicted label with surrounding whitespace removed. The prompt
            asks for "Valid Topic" or "Invalid Topic".
        """
        # DSPy recommends invoking a module via __call__ rather than .forward().
        result = _get_topic_validator()(topic=topic)
        return (result.valid_topic or "").strip()


# # === CLI Test Usage ===
# if __name__ == "__main__":
#     tool = TopicValidationTool()
#     result = tool._run(
#     # topic = "abc123"
#     topic = "Quantum Entanglement"

#     )
#     print(result)








# from typing import Type
# from crewai.tools import BaseTool
# from pydantic import BaseModel, Field
# import dspy
# import os
# from dotenv import load_dotenv
# from dspy.teleprompt import LabeledFewShot

# # === Load secrets ===
# load_dotenv(dotenv_path="/home/azureuser/microlearn/backend/prompt_opt_dspy/.env")

# azure_api_key = os.getenv("AZURE_API_KEY")
# azure_api_host = os.getenv("AZURE_API_HOST")
# azure_api_version = os.getenv("AZURE_API_VERSION")

# # === Configure Azure OpenAI LM ===
# lm = dspy.LM(
#     "azure/Csqr-gpt-4o-mini",
#     api_key=azure_api_key,
#     api_base=azure_api_host,
#     api_version=azure_api_version,
#     temperature=0.0,
# )
# dspy.configure(lm=lm)


# # === DSPy Signature ===
# class TopicValidationSignature(dspy.Signature):
#     """Determine if a given topic is a valid and meaningful topic."""
#     topic: str = dspy.InputField()
#     valid_topic: str = dspy.OutputField(desc="Either 'Valid Topic' or 'Invalid Topic'")


# # === DSPy Module ===
# class TopicValidationModule(dspy.Module):
#     def __init__(self):
#         super().__init__()
#         predictor = dspy.Predict(TopicValidationSignature)

#         # --- Training Data ---
#         training_data = [
#             dspy.Example(topic="Artificial Intelligence", valid_topic="Valid Topic").with_inputs("topic"),
#             dspy.Example(topic="qwertyuiop", valid_topic="Invalid Topic").with_inputs("topic"),
#             dspy.Example(topic="Climate Change", valid_topic="Valid Topic").with_inputs("topic"),
#             dspy.Example(topic="asdf", valid_topic="Invalid Topic").with_inputs("topic"),
#             dspy.Example(topic="Blockchain", valid_topic="Valid Topic").with_inputs("topic"),
#         ]

#         # --- Optimizer ---
#         optimizer = LabeledFewShot(k=3)
#         self.optimized_predictor = optimizer.compile(
#             predictor,
#             trainset=training_data
#         )

#     def forward(self, topic: str):
#         return self.optimized_predictor(topic=topic)


# # === Tool Input Schema ===
# class TopicValidationInput(BaseModel):
#     """Input schema for TopicValidationTool."""
#     topic: str = Field(..., description="The topic string to validate")


# # === Custom DSPy Tool ===
# class TopicValidationTool(BaseTool):
#     name: str = "Topic Validation Tool"
#     description: str = (
#         "Validates if a given topic is meaningful. Returns 'Valid Topic' or 'Invalid Topic'."
#     )
#     args_schema: Type[BaseModel] = TopicValidationInput

#     def _run(self, topic: str) -> str:
#         validator = TopicValidationModule()
#         result = validator.forward(topic=topic)
#         return result.valid_topic


# # # === CLI Test Usage ===
# # if __name__ == "__main__":
# #     tool = TopicValidationTool()
# #     print("=== Testing Topic Validation ===")
# #     output1 = tool._run(topic="Cyber Security")
# #     print(f"The topic 'Cyber Security' is a: {output1}")
# #     output2 = tool._run(topic="lkjghfdsa")
# #     print(f"The topic 'lkjghfdsa' is a: {output2}")
