from typing import Type
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import dspy
import os
from dotenv import load_dotenv


# Load Azure OpenAI settings from a .env file. The path defaults to the
# original deployment location; set DOTENV_PATH to load a different file
# (e.g. when running outside that host).
_DEFAULT_DOTENV_PATH = "/home/azureuser/microlearn/backend/prompt_opt_dspy/.env"
load_dotenv(dotenv_path=os.getenv("DOTENV_PATH", _DEFAULT_DOTENV_PATH))

azure_api_key = os.getenv("AZURE_API_KEY")
azure_api_host = os.getenv("AZURE_API_HOST")
azure_deployment_id = os.getenv("AZURE_DEPLOYMENT_ID")
azure_api_version = os.getenv("AZURE_API_VERSION")

# === Configure Azure OpenAI LM ===
# Use the deployment named by AZURE_DEPLOYMENT_ID, falling back to
# "Csqr-gpt-4o-mini" when the variable is unset or empty.
_DEFAULT_DEPLOYMENT_ID = "Csqr-gpt-4o-mini"
lm = dspy.LM(
    f"azure/{azure_deployment_id or _DEFAULT_DEPLOYMENT_ID}",
    api_key=azure_api_key,
    api_base=azure_api_host,
    api_version=azure_api_version,
    temperature=0.0,
)
# Make this LM the global default for every DSPy predictor in the process.
dspy.configure(lm=lm)


# === DSPy Signature ===
# NOTE: DSPy sends this class docstring to the LM as the task instructions, and
# it supplies the input field values separately. Do not put {placeholder}
# templates in the docstring: they are never substituted.
class DefaultsEstimatorSignature(dspy.Signature):
    """
    Estimate a user's default proficiency in five data-related skill categories.

    You are an expert skill analyst. Using the user's job title, the topic they
    are focusing on, their experience with that topic, and their learning level,
    estimate how proficient the user is likely to be in each of the following
    categories, as a percentage from 0% to 100%:
    1. SQL (SQL)
    2. Data Visualization (Viz)
    3. Statistical Analysis (Stats)
    4. Business Communication (Comm)
    5. Data Storytelling (Story)

    Each percentage is scored independently out of 100%, so the five values do
    not need to add up to 100%. They should reflect the user's level in each
    skill in their current role.

    Respond with exactly these five lines, replacing ... with a whole number,
    and do not include any other comments:
    ## SQL - ...%
    ## Viz - ...%
    ## Stats - ...%
    ## Comm - ...%
    ## Story - ...%
    """
    # Field descriptions mirror DefaultsEstimatorInput so the LM and the tool
    # schema describe the inputs the same way.
    Job_Title: str = dspy.InputField(desc="User's job title")
    topic: str = dspy.InputField(desc="Topic of interest or focus area")
    Experience: str = dspy.InputField(desc="User's experience with the topic")
    Level: str = dspy.InputField(
        desc="Learning level (Beginner, Intermediate, Advanced)"
    )
    skill_distribution: str = dspy.OutputField(
        desc="Percentages across SQL, Viz, Stats, Comm, Story. Format only with headings."
    )

# === DSPy Module ===
class DefaultsEstimatorModule(dspy.Module):
    """Estimate a user's default skill percentages with a few-shot DSPy predictor.

    On construction, a ``dspy.Predict(DefaultsEstimatorSignature)`` predictor is
    optimized with ``BootstrapFewShot`` over two hand-written examples. This
    calls the globally configured LM. ``forward`` then runs the optimized
    predictor.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): this predictor is never used by this module; it is kept
        # only in case external code reads the attribute.
        self.restructure_predictor = dspy.Predict(DefaultsEstimatorSignature)
        predictor = dspy.Predict(DefaultsEstimatorSignature)

        # --- Two example training shots ---
        # The outputs are written without indentation so the demos match the
        # exact five-line format requested by the signature.
        trainset = [
            dspy.Example(
                Job_Title="Data Analyst",
                topic="SQL",
                Experience="2 years",
                Level="Intermediate",
                skill_distribution=(
                    "## SQL - 70%\n"
                    "## Viz - 60%\n"
                    "## Stats - 50%\n"
                    "## Comm - 55%\n"
                    "## Story - 45%"
                ),
            ).with_inputs("Job_Title", "topic", "Experience", "Level"),
            dspy.Example(
                Job_Title="Business Analyst",
                topic="Data Visualization",
                Experience="3 years",
                Level="Beginner",
                skill_distribution=(
                    "## SQL - 40%\n"
                    "## Viz - 75%\n"
                    "## Stats - 50%\n"
                    "## Comm - 70%\n"
                    "## Story - 65%"
                ),
            ).with_inputs("Job_Title", "topic", "Experience", "Level"),
        ]

        # --- Custom metric: ensure all 5 categories are present ---
        def estimator_metric(gold, pred, trace=None):
            """Return 1 if ``pred`` mentions all five skill categories, else 0.

            A prediction with a missing or empty ``skill_distribution`` scores 0
            instead of raising.
            """
            out = (getattr(pred, "skill_distribution", None) or "").lower()
            return int(
                all(skill in out for skill in ("sql", "viz", "stats", "comm", "story"))
            )

        # Train with BootstrapFewShot. compile() optimizes a copy of the student
        # and returns it; the predictor passed in is left unchanged. The
        # returned program is the one that must be used.
        optimizer = dspy.BootstrapFewShot(metric=estimator_metric)
        compiled_predictor = optimizer.compile(predictor, trainset=trainset)

        self.defaults_estimator_predictor = compiled_predictor
        print("=== Optimized Defaults Estimator Predictor ===")
        print(compiled_predictor)

    def forward(self, Job_Title, topic, Experience, Level):
        """Run the optimized predictor on one user profile.

        Returns the ``dspy.Prediction`` from the predictor. Its
        ``skill_distribution`` attribute holds the LM's formatted percentages.
        """
        return self.defaults_estimator_predictor(
            Job_Title=Job_Title,
            topic=topic,
            Experience=Experience,
            Level=Level,
        )

# === Tool Input Schema ===
class DefaultsEstimatorInput(BaseModel):
    """Input schema for DefaultsEstimatorTool."""

    # Every field is required (``default=...``); the names match the keyword
    # parameters of DefaultsEstimatorTool._run.
    Job_Title: str = Field(default=..., description="User's job title")
    topic: str = Field(
        default=...,
        description="Topic of interest or focus area",
    )
    Experience: str = Field(
        default=...,
        description="User's experience with the topic",
    )
    Level: str = Field(
        default=...,
        description="Learning level (Beginner, Intermediate, Advanced)",
    )


# === Custom DSPy Tool ===
# Constructing DefaultsEstimatorModule runs BootstrapFewShot against the LM, so
# build it once on first use and reuse it for every tool invocation.
_defaults_estimator = None


def _get_defaults_estimator():
    """Return the shared DefaultsEstimatorModule, constructing it on first call."""
    global _defaults_estimator
    if _defaults_estimator is None:
        _defaults_estimator = DefaultsEstimatorModule()
    return _defaults_estimator


class DefaultsEstimatorTool(BaseTool):
    """crewai tool wrapping DefaultsEstimatorModule for skill-level estimation."""

    name: str = "Defaults Estimator Tool"
    description: str = (
        "Analyzes a user's profile and context to estimate their default proficiency levels "
        "across SQL, Data Visualization, Statistical Analysis, Business Communication, and Data Storytelling."
    )
    args_schema: Type[BaseModel] = DefaultsEstimatorInput

    def _run(self, Job_Title: str, topic: str, Experience: str, Level: str) -> str:
        """Estimate skill percentages for one user profile.

        Returns the ``skill_distribution`` text produced by the estimator.
        """
        estimator = _get_defaults_estimator()
        result = estimator.forward(Job_Title, topic, Experience, Level)
        return result.skill_distribution


# if __name__ == "__main__":
#     tool = DefaultsEstimatorTool()
#     output = tool.run(
#         Job_Title="Data Scientist",
#         topic="Machine Learning",
#         Experience="4 years",
#         Level="Advanced"
#     )
#     print("=== Estimated Skills ===")
#     print(output)
