import re
from crewai import TaskOutput, Task
from typing import Optional, Tuple, Union

class GuardRailFn:
    """Guardrail helpers that inspect the raw text of a crewai ``TaskOutput``.

    All helpers are static: the class is only a namespace and holds no state.
    Each guardrail returns a ``(passed, payload)`` tuple where ``payload`` is
    the original task output on success and a plain error string on failure.
    """

    # Heading such as ``### Duration`` / ``### **Duration**``; captures the
    # remainder of that line.
    _DURATION_LINE_RE = re.compile(r'###\s*\*{0,2}Duration\*{0,2}\s*(.*)')

    # A number (optionally decimal) followed by an hour/minute unit. Matching
    # is case-insensitive so "2 Hours" and "30 Minutes" are recognised too.
    _AMOUNT_RE = re.compile(
        r'(\d+(?:\.\d+)?)\s*(hours|hour|minutes|minute)', re.IGNORECASE
    )

    # A skill line such as ``## SQL - 40%``.
    _SKILL_RE = re.compile(r"##\s*([\w\s]+?)\s*-\s*(\d+)%")

    # Durations are floats; treat differences below this many minutes as equal
    # so that sums like 0.1 h + 0.2 h still match an available time of 0.3 h.
    _TOLERANCE_MINUTES = 1e-6

    # Message returned when the skill-percentage format is not found.
    _FORMAT_ERROR = (
        "The format of the output is not correct. Follow this format -  \n"
        "            ## SQL - ....% \n"
        "            ## Viz - ....% \n"
        "            ## Stats - ....%\n"
        "            ## Comm - ....%\n"
        "            ## Story - ....%"
    )

    @staticmethod
    def extract_and_sum_durations(text):
        """Sum the durations declared on ``### Duration`` lines of *text*.

        Only the first ``<number> <unit>`` on each Duration line is counted, so
        a bracketed equivalent such as ``90 minutes (1.5 hours)`` is not
        counted twice.

        Args:
            text: Text to scan.

        Returns:
            ``(total_minutes, total_hours)`` as floats. The two totals are kept
            separate; hours are not converted to minutes here.
        """
        total_minutes = 0.0
        total_hours = 0.0

        for line in GuardRailFn._DURATION_LINE_RE.findall(text):
            match = GuardRailFn._AMOUNT_RE.search(line)
            if match is None:
                continue
            value = float(match.group(1))
            unit = match.group(2).lower()
            if 'minute' in unit:
                total_minutes += value
            else:
                # The pattern only admits hour/minute units.
                total_hours += value

        return total_minutes, total_hours

    @staticmethod
    def parse_available_time(time_str):
        """Parse a free-form time budget such as ``"1 hour 30 minutes"``.

        Every ``<number> <unit>`` occurrence in *time_str* is added up.

        Args:
            time_str: Text describing the available time.

        Returns:
            ``(minutes, hours)`` as floats, each summed independently.
        """
        hours = 0.0
        minutes = 0.0

        for raw_value, raw_unit in GuardRailFn._AMOUNT_RE.findall(time_str):
            value = float(raw_value)
            unit = raw_unit.lower()
            if 'hour' in unit:
                print(f"Hour unit detected : {value}")
                hours += value
            else:
                print(f"Minute unit detected : {value}")
                minutes += value

        return minutes, hours

    @staticmethod
    def validate_content(result: TaskOutput, available_time: str) -> Tuple[bool, Union[str, TaskOutput]]:
        """Check that the durations in ``result.raw`` add up to *available_time*.

        Args:
            result: Task output whose ``raw`` text contains ``### Duration``
                lines.
            available_time: Time budget, e.g. ``"2 hours 30 minutes"``.

        Returns:
            ``(True, result)`` when the planned total equals the available
            time (within a tiny float tolerance); otherwise ``(False, message)``
            where ``message`` says by how many minutes the plan is short or
            over, or describes an unexpected error.
        """
        try:
            print(f"The task outputs type is : {type(result)}")
            total_minutes, total_hours = GuardRailFn.extract_and_sum_durations(result.raw)
            print(f"The total minutes is : {total_minutes} and total hours is : {total_hours}")
            input_minutes, input_hours = GuardRailFn.parse_available_time(available_time)
            print(f"The available time minutes is : {input_minutes} and available time hours is : {input_hours}")

            task_total_minutes = total_minutes + (total_hours * 60)
            available_total_minutes = input_minutes + (input_hours * 60)
            time_difference_minutes = available_total_minutes - task_total_minutes

            # Exact float equality would reject plans that differ only by
            # rounding error, so compare against a small tolerance instead.
            if abs(time_difference_minutes) <= GuardRailFn._TOLERANCE_MINUTES:
                return (True, result)
            if time_difference_minutes > 0:
                return (False, f"The user journey duration is less than the available time by {time_difference_minutes} minutes.")
            return (False, f"The user journey duration exceeds the available time by {abs(time_difference_minutes)} minutes.")
        except Exception as e:
            # Guardrail boundary: never let an exception escape. Report the
            # failure as a plain string, like the other failure branches,
            # rather than constructing a TaskOutput inside the handler (an
            # error raised there would itself escape the guardrail).
            return (False, f"Unexpected error during validation: {e}")

    @staticmethod
    def get_validate_content_guardrail(available_time: str):
        """Return a one-argument guardrail bound to *available_time*.

        The returned callable takes only the task output and delegates to
        :meth:`validate_content`.
        """
        def validate(result: TaskOutput) -> Tuple[bool, Union[str, TaskOutput]]:
            return GuardRailFn.validate_content(result, available_time)
        return validate

    @staticmethod
    def validate_op_defaults(result: TaskOutput) -> Tuple[bool, Union[str, TaskOutput]]:
        """Check that ``result.raw`` contains ``## <Skill> - <NN>%`` lines.

        Args:
            result: Task output whose ``raw`` text is inspected.

        Returns:
            ``(True, result)`` when at least one skill/percentage line is
            found; otherwise ``(False, message)`` describing the expected
            format.
        """
        print("Inside guardrail function")
        try:
            matches = GuardRailFn._SKILL_RE.findall(result.raw)
        except (AttributeError, TypeError) as e:
            # ``raw`` is missing or is not text.
            print(e)
            return (False, GuardRailFn._FORMAT_ERROR)

        # ``findall`` returns an empty list instead of raising when nothing
        # matches, so the "wrong format" case has to be checked explicitly.
        if not matches:
            return (False, GuardRailFn._FORMAT_ERROR)

        print("match found")
        return (True, result)

    @staticmethod
    def validate_defaults_op():
        """Return a one-argument guardrail wrapping :meth:`validate_op_defaults`."""
        def validate(result: TaskOutput) -> Tuple[bool, Union[str, TaskOutput]]:
            return GuardRailFn.validate_op_defaults(result)
        return validate
            
