
LangGraph Sample Project#

Objective#

  • Independently deployable services

    • Each agent can scale horizontally (e.g., analysis_service replicas)
    • Agents can be versioned and deployed independently
  • Schema isolation

    • Each service defines its own Pydantic input/output models
    • The supervisor handles schema translation between services
  • Resilience

    • The supervisor can retry subgraph calls and add timeout handling (see the sketch below)
  • Observability

    • Inter-agent calls can be traced via httpx middleware or OpenTelemetry
  • Extensibility

    • New agents (summarizer_service, retriever_service, etc.) can be added as needed
    • The supervisor graph can grow dynamically without coupling services together
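
To make the resilience and observability points concrete, here is a minimal sketch of how a supervisor could wrap a downstream call with a timeout, simple retries, and request logging via httpx event hooks. The call_with_retries helper and its backoff policy are illustrative assumptions, not part of the project code:

    import asyncio
    import logging

    import httpx

    logger = logging.getLogger(__name__)


    async def log_request(request: httpx.Request) -> None:
        # Observability hook: trace every outgoing inter-agent call
        logger.info(f"-> {request.method} {request.url}")


    async def call_with_retries(url: str, payload: dict, retries: int = 3) -> dict:
        """Hypothetical helper: POST to a downstream agent with timeout + retry."""
        last_error: Exception | None = None
        for attempt in range(1, retries + 1):
            try:
                async with httpx.AsyncClient(
                    timeout=30.0, event_hooks={"request": [log_request]}
                ) as client:
                    response = await client.post(url, json=payload)
                    response.raise_for_status()
                    return response.json()
            except (httpx.TimeoutException, httpx.HTTPStatusError) as e:
                last_error = e
                await asyncio.sleep(2 ** attempt)  # exponential backoff
        raise RuntimeError(f"Call failed after {retries} attempts") from last_error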

Project Structure#

🎯 What's Included:

  • 3 Independent Services:
    • Research Service (Port 8081) - Handles research queries with validation, planning, gathering, and summarization
    • Analysis Service (Port 8082) - Extracts insights, generates recommendations, and creates analysis reports
    • Supervisor Service (Port 8080) - Orchestrates the entire workflow via REST API calls
  • Key Features:

    • ✅ Independent LangGraph workflows in each service
    • ✅ Shared error handler with per-node retry tracking
    • ✅ REST API communication between services
    • ✅ Parallel node support - no state conflicts
    • ✅ Docker deployment ready with docker-compose
    • ✅ Health checks for monitoring
    • ✅ Comprehensive error handling with automatic retries
    • ✅ Test scripts for validation
Microservice-style LangGraph architecture:

agentic_app/
├── shared/
│   └── error_handler.py           # Shared error handling logic
├── research_service/
│   ├── main.py                    # FastAPI app
│   ├── schema.py                  # Pydantic models
│   ├── graph.py                   # LangGraph workflow
│   ├── requirements.txt
│   └── Dockerfile
├── analysis_service/
│   ├── main.py                    # FastAPI app
│   ├── schema.py                  # Pydantic models
│   ├── graph.py                   # LangGraph workflow
│   ├── requirements.txt
│   └── Dockerfile
├── supervisor_service/
│   ├── main.py                    # FastAPI app
│   ├── schema.py                  # Pydantic models
│   ├── graph.py                   # LangGraph workflow
│   ├── requirements.txt
│   └── Dockerfile
├── docker-compose.yml
├── run_services.sh
└── test_system.py

Services#

Research Service

research_service/
├── main.py                    # FastAPI app
├── schema.py                  # Pydantic models
└── graph.py                   # LangGraph workflow

schema.py
  from pydantic import BaseModel, ConfigDict, Field


  class ResearchState(BaseModel):
      model_config = ConfigDict(arbitrary_types_allowed=True)
      query: str = ""
      research_plan: str = ""
      search_results: list[str] = Field(default_factory=list)
      summary: str = ""

      # Error handling
      error_messages: list[str] = Field(default_factory=list)
      retry_count: int = 0
      max_retries: int = 3
      failed_nodes: dict[str, int] = Field(default_factory=dict)


  class ResearchRequest(BaseModel):
      query: str
      max_retries: int = 3


  class ResearchResponse(BaseModel):
      query: str
      research_plan: str
      search_results: list[str]
      summary: str
      error_messages: list[str]
      failed_nodes: dict[str, int]
      success: bool
graph.py
  import logging
  from typing import Optional

  from langgraph.graph import StateGraph, END
  from langchain_ollama import ChatOllama
  from langchain_core.messages import HumanMessage, SystemMessage

  from agentic_app.research_service.schema import ResearchState
  from agentic_app.shared.error_handler import handle_node_errors, create_universal_router

  logger = logging.getLogger(__name__)


  class ResearchNodes:
      def __init__(self, llm: Optional[ChatOllama] = None):
          self.llm = llm or ChatOllama(model="gpt-oss", temperature=0)

      @handle_node_errors("validate_query", "Failed to validate query")
      def validate_query(self, state: ResearchState) -> dict:
          logger.info(f"Validating query: {state.query}")

          if not state.query or len(state.query.strip()) < 5:
              raise ValueError("Query must be at least 5 characters long")

          return {}

      @handle_node_errors("create_plan", "Failed to create research plan")
      async def create_plan(self, state: ResearchState) -> dict:
          logger.info("Creating research plan")

          messages = [
              SystemMessage(content="Create a brief 3-step research plan."),
              HumanMessage(content=f"Create a research plan for: {state.query}")
          ]

          response = await self.llm.ainvoke(messages)

          return {
              "research_plan": response.content
          }

      @handle_node_errors("gather_info", "Failed to gather information")
      async def gather_info(self, state: ResearchState) -> dict:
          logger.info("Gathering information")

          # Simulate research gathering
          search_results = [
              f"Finding 1 about {state.query}",
              f"Finding 2 about {state.query}",
              f"Finding 3 about {state.query}",
          ]

          return {
              "search_results": search_results
          }

      @handle_node_errors("summarize", "Failed to summarize")
      async def summarize(self, state: ResearchState) -> dict:
          logger.info("Summarizing findings")

          findings = "\n".join(f"- {r}" for r in state.search_results)

          messages = [
              SystemMessage(content="Summarize these research findings concisely."),
              HumanMessage(content=f"Plan: {state.research_plan}\n\nFindings:\n{findings}")
          ]

          response = await self.llm.ainvoke(messages)

          return {
              "summary": response.content
          }


  def create_research_graph():
      nodes = ResearchNodes()
      workflow = StateGraph(ResearchState)

      workflow.add_node("validate_query", nodes.validate_query)
      workflow.add_node("create_plan", nodes.create_plan)
      workflow.add_node("gather_info", nodes.gather_info)
      workflow.add_node("summarize", nodes.summarize)

      workflow.set_entry_point("validate_query")

      workflow.add_conditional_edges(
          "validate_query",
          create_universal_router(next_node="create_plan", node_name="validate_query")
      )
      workflow.add_conditional_edges(
          "create_plan",
          create_universal_router(next_node="gather_info", node_name="create_plan")
      )
      workflow.add_conditional_edges(
          "gather_info",
          create_universal_router(next_node="summarize", node_name="gather_info")
      )
      workflow.add_conditional_edges(
          "summarize",
          create_universal_router(next_node=END, node_name="summarize")
      )

      return workflow.compile()
main.py
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    import logging

    from agentic_app.research_service.graph import create_research_graph
    from agentic_app.research_service.schema import ResearchState, ResearchRequest, ResearchResponse

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    app = FastAPI(title="Research Service", version="1.0.0")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Initialize graph
    research_graph = create_research_graph()


    @app.post("/research", response_model=ResearchResponse)
    async def conduct_research(request: ResearchRequest):
        """Conduct research on a given query"""
        try:
            logger.info(f"Received research request: {request.query}")

            initial_state = ResearchState(
                query=request.query,
                max_retries=request.max_retries
            )

            final_state = await research_graph.ainvoke(initial_state)

            return ResearchResponse(
                query=final_state.get("query", ""),
                research_plan=final_state.get("research_plan", ""),
                search_results=final_state.get("search_results", []),
                summary=final_state.get("summary", ""),
                error_messages=final_state.get("error_messages", []),
                failed_nodes=final_state.get("failed_nodes", {}),
                success=len(final_state.get("error_messages", [])) == 0
            )

        except Exception as e:
            logger.error(f"Research failed: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))


    @app.get("/health")
    async def health():
        return {"status": "healthy", "service": "research"}


    if __name__ == '__main__':
        import uvicorn
        uvicorn.run(app, host="0.0.0.0", port=8081)
Analysis Service

analysis_service/
├── main.py                    # FastAPI app
├── schema.py                  # Pydantic models
└── graph.py                   # LangGraph workflow

schema.py
  from pydantic import BaseModel, Field, ConfigDict

  class AnalysisState(BaseModel):
      model_config = ConfigDict(arbitrary_types_allowed=True)
      research_summary: str = ""
      insights: list[str] = Field(default_factory=list)
      recommendations: list[str] = Field(default_factory=list)
      final_analysis: str = ""

      # Error handling
      error_messages: list[str] = Field(default_factory=list)
      retry_count: int = 0
      max_retries: int = 3
      failed_nodes: dict[str, int] = Field(default_factory=dict)


  class AnalysisRequest(BaseModel):
      research_summary: str
      max_retries: int = 3


  class AnalysisResponse(BaseModel):
      insights: list[str]
      recommendations: list[str]
      final_analysis: str
      error_messages: list[str]
      failed_nodes: dict[str, int]
      success: bool
graph.py
    import logging
    from typing import Optional

    from langgraph.graph import StateGraph, END
    from langchain_ollama import ChatOllama
    from langchain_core.messages import HumanMessage, SystemMessage

    from agentic_app.analysis_service.schema import AnalysisState
    from agentic_app.shared.error_handler import handle_node_errors, create_universal_router

    logger = logging.getLogger(__name__)


    class AnalysisNodes:
        def __init__(self, llm: Optional[ChatOllama] = None):
            self.llm = llm or ChatOllama(model="gpt-oss", temperature=0)

        @handle_node_errors("extract_insights", "Failed to extract insights")
        async def extract_insights(self, state: AnalysisState) -> dict:
            logger.info("Extracting insights")

            messages = [
                SystemMessage(content="Extract 3 key insights from this research."),
                HumanMessage(content=state.research_summary)
            ]

            response = await self.llm.ainvoke(messages)

            # Parse insights (simplified)
            insights = [line.strip() for line in response.content.split('\n') if line.strip()][:3]

            return {
                "insights": insights
            }

        @handle_node_errors("generate_recommendations", "Failed to generate recommendations")
        async def generate_recommendations(self, state: AnalysisState) -> dict:
            logger.info("Generating recommendations")

            insights_text = "\n".join(f"- {i}" for i in state.insights)

            messages = [
                SystemMessage(content="Generate 3 actionable recommendations based on these insights."),
                HumanMessage(content=insights_text)
            ]

            response = await self.llm.ainvoke(messages)

            recommendations = [line.strip() for line in response.content.split('\n') if line.strip()][:3]

            return {
                "recommendations": recommendations
            }

        @handle_node_errors("create_analysis", "Failed to create final analysis")
        async def create_analysis(self, state: AnalysisState) -> dict:
            logger.info("Creating final analysis")

            messages = [
                SystemMessage(content="Create a concise final analysis report."),
                HumanMessage(
                    content=f"Summary: {state.research_summary}\n\nInsights: {state.insights}\n\nRecommendations: {state.recommendations}")
            ]

            response = await self.llm.ainvoke(messages)

            return {
                "final_analysis": response.content
            }


    def create_analysis_graph():
        nodes = AnalysisNodes()
        workflow = StateGraph(AnalysisState)

        workflow.add_node("extract_insights", nodes.extract_insights)
        workflow.add_node("generate_recommendations", nodes.generate_recommendations)
        workflow.add_node("create_analysis", nodes.create_analysis)

        workflow.set_entry_point("extract_insights")

        workflow.add_conditional_edges(
            "extract_insights",
            create_universal_router(next_node="generate_recommendations", node_name="extract_insights")
        )
        workflow.add_conditional_edges(
            "generate_recommendations",
            create_universal_router(next_node="create_analysis", node_name="generate_recommendations")
        )
        workflow.add_conditional_edges(
            "create_analysis",
            create_universal_router(next_node=END, node_name="create_analysis")
        )

        return workflow.compile()
main.py
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    import logging

    from agentic_app.analysis_service.graph import create_analysis_graph
    from agentic_app.analysis_service.schema import AnalysisState, AnalysisRequest, AnalysisResponse

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    app = FastAPI(title="Analysis Service", version="1.0.0")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    analysis_graph = create_analysis_graph()


    @app.post("/analyze", response_model=AnalysisResponse)
    async def analyze_research(request: AnalysisRequest):
        """Analyze research summary and generate insights"""
        try:
            logger.info("Received analysis request")

            initial_state = AnalysisState(
                research_summary=request.research_summary,
                max_retries=request.max_retries
            )

            final_state = await analysis_graph.ainvoke(initial_state)

            return AnalysisResponse(
                insights=final_state.get("insights", []),
                recommendations=final_state.get("recommendations", []),
                final_analysis=final_state.get("final_analysis", ""),
                error_messages=final_state.get("error_messages", []),
                failed_nodes=final_state.get("failed_nodes", {}),
                success=len(final_state.get("error_messages", [])) == 0
            )

        except Exception as e:
            logger.error(f"Analysis failed: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))


    @app.get("/health")
    async def health():
        return {"status": "healthy", "service": "analysis"}

    if __name__ == '__main__':
        import uvicorn
        uvicorn.run(app, host="0.0.0.0", port=8082)
Supervisor Service

supervisor_service/
├── main.py                    # FastAPI app
├── schema.py                  # Pydantic models
└── graph.py                   # LangGraph workflow

schema.py
  from pydantic import BaseModel, ConfigDict, Field


  class SupervisorState(BaseModel):
      model_config = ConfigDict(arbitrary_types_allowed=True)
      original_query: str = ""
      research_result: dict = Field(default_factory=dict)
      analysis_result: dict = Field(default_factory=dict)
      final_report: str = ""

      # Error handling
      error_messages: list[str] = Field(default_factory=list)
      retry_count: int = 0
      max_retries: int = 3
      failed_nodes: dict[str, int] = Field(default_factory=dict)


  class SupervisorRequest(BaseModel):
      query: str
      max_retries: int = 3
      research_service_url: str = "http://localhost:8081"
      analysis_service_url: str = "http://localhost:8082"


  class SupervisorResponse(BaseModel):
      query: str
      research_summary: str
      analysis_report: str
      final_report: str
      error_messages: list[str]
      failed_nodes: dict[str, int]
      success: bool
graph.py
    import logging
    from typing import Optional

    import httpx
    from langgraph.graph import StateGraph, END
    from langchain_ollama import ChatOllama
    from langchain_core.messages import HumanMessage, SystemMessage

    from agentic_app.shared.error_handler import handle_node_errors, create_universal_router
    from agentic_app.supervisor_service.schema import SupervisorState

    logger = logging.getLogger(__name__)


    class SupervisorNodes:
        def __init__(self, research_url: str, analysis_url: str, llm: Optional[ChatOllama] = None):
            self.research_url = research_url
            self.analysis_url = analysis_url
            self.llm = llm or ChatOllama(model="gpt-oss", temperature=0)

        @handle_node_errors("call_research", "Failed to call research service")
        async def call_research(self, state: SupervisorState) -> dict:
            logger.info(f"Calling research service at {self.research_url}")

            async with httpx.AsyncClient(timeout=300.0) as client:
                response = await client.post(
                    f"{self.research_url}/research",
                    json={"query": state.original_query, "max_retries": state.max_retries}
                )
                response.raise_for_status()
                result = response.json()

            if not result.get("success"):
                raise Exception(f"Research service failed: {result.get('error_messages')}")

            return {
                "research_result": result
            }

        @handle_node_errors("call_analysis", "Failed to call analysis service")
        async def call_analysis(self, state: SupervisorState) -> dict:
            logger.info(f"Calling analysis service at {self.analysis_url}")

            research_summary = state.research_result.get("summary", "")

            async with httpx.AsyncClient(timeout=300.0) as client:
                response = await client.post(
                    f"{self.analysis_url}/analyze",
                    json={"research_summary": research_summary, "max_retries": state.max_retries}
                )
                response.raise_for_status()
                result = response.json()

            if not result.get("success"):
                raise Exception(f"Analysis service failed: {result.get('error_messages')}")

            return {
                "analysis_result": result
            }

        @handle_node_errors("generate_final_report", "Failed to generate final report")
        async def generate_final_report(self, state: SupervisorState) -> dict:
            logger.info("Generating final report")

            messages = [
                SystemMessage(content="Create a comprehensive final report combining research and analysis."),
                HumanMessage(
                    content=f"Query: {state.original_query}\n\nResearch: {state.research_result.get('summary')}\n\nAnalysis: {state.analysis_result.get('final_analysis')}")
            ]

            response = await self.llm.ainvoke(messages)

            return {
                "final_report": response.content
            }


    def create_supervisor_graph(research_url: str, analysis_url: str):
        nodes = SupervisorNodes(research_url, analysis_url)
        workflow = StateGraph(SupervisorState)

        workflow.add_node("call_research", nodes.call_research)
        workflow.add_node("call_analysis", nodes.call_analysis)
        workflow.add_node("generate_final_report", nodes.generate_final_report)

        workflow.set_entry_point("call_research")

        workflow.add_conditional_edges(
            "call_research",
            create_universal_router(next_node="call_analysis", node_name="call_research")
        )
        workflow.add_conditional_edges(
            "call_analysis",
            create_universal_router(next_node="generate_final_report", node_name="call_analysis")
        )
        workflow.add_conditional_edges(
            "generate_final_report",
            create_universal_router(next_node=END, node_name="generate_final_report")
        )

        return workflow.compile()
main.py
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    import logging

    from agentic_app.supervisor_service.graph import create_supervisor_graph
    from agentic_app.supervisor_service.schema import SupervisorState, SupervisorRequest, SupervisorResponse

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    app = FastAPI(title="Supervisor Service", version="1.0.0")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )


    # Reconstructed entry point (the original page repeated the research
    # service listing here); the /process route name is an assumption.
    @app.post("/process", response_model=SupervisorResponse)
    async def process_query(request: SupervisorRequest):
        """Orchestrate the research and analysis services for a query"""
        try:
            logger.info(f"Received supervisor request: {request.query}")

            # Build the graph per request so the downstream URLs can vary
            supervisor_graph = create_supervisor_graph(
                request.research_service_url,
                request.analysis_service_url
            )

            initial_state = SupervisorState(
                original_query=request.query,
                max_retries=request.max_retries
            )

            final_state = await supervisor_graph.ainvoke(initial_state)

            return SupervisorResponse(
                query=final_state.get("original_query", ""),
                research_summary=final_state.get("research_result", {}).get("summary", ""),
                analysis_report=final_state.get("analysis_result", {}).get("final_analysis", ""),
                final_report=final_state.get("final_report", ""),
                error_messages=final_state.get("error_messages", []),
                failed_nodes=final_state.get("failed_nodes", {}),
                success=len(final_state.get("error_messages", [])) == 0
            )

        except Exception as e:
            logger.error(f"Supervisor run failed: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))


    @app.get("/health")
    async def health():
        return {"status": "healthy", "service": "supervisor"}


    if __name__ == '__main__':
        import uvicorn
        uvicorn.run(app, host="0.0.0.0", port=8080)
Others
error_handler.py
import asyncio
import logging
from functools import wraps
from typing import Protocol, Optional, TypeVar, Any, Callable

from langgraph.graph import END

logger = logging.getLogger(__name__)


class ErrorState(Protocol):
    error_messages: list[str]
    retry_count: int
    max_retries: int
    failed_nodes: dict[str, int]


StateType = TypeVar('StateType', bound=ErrorState)


class ErrorHandler:
    @staticmethod
    def handle_error(state: StateType, error: Exception, node_name: str, custom_message: Optional[str] = None) -> dict:
        error_msg = custom_message or f"Error in {node_name}: {str(error)}"
        logger.error(f"Node '{node_name}' failed: {str(error)}")

        failed_nodes = dict(
            getattr(state, 'failed_nodes', {}) if hasattr(state, 'failed_nodes') else state.get('failed_nodes', {}))
        node_retry_count = failed_nodes.get(node_name, 0) + 1
        failed_nodes[node_name] = node_retry_count

        retry_count = getattr(state, 'retry_count', 0) if hasattr(state, 'retry_count') else state.get('retry_count', 0)

        return {
            "error_messages": [f"{node_name}: {error_msg}"],
            "retry_count": retry_count + 1,
            "failed_nodes": failed_nodes
        }

    @staticmethod
    def should_retry(state: StateType, node_name: Optional[str] = None) -> bool:
        if node_name:
            failed_nodes = getattr(state, 'failed_nodes', {}) if hasattr(state, 'failed_nodes') else state.get(
                'failed_nodes', {})
            node_retry_count = failed_nodes.get(node_name, 0)
            max_retries = getattr(state, 'max_retries', 3) if hasattr(state, 'max_retries') else state.get(
                'max_retries', 3)
            error_messages = getattr(state, 'error_messages', []) if hasattr(state, 'error_messages') else state.get(
                'error_messages', [])
            has_node_error = any(node_name in msg for msg in error_messages)
            return node_retry_count < max_retries and has_node_error
        else:
            retry_count = getattr(state, 'retry_count', 0) if hasattr(state, 'retry_count') else state.get(
                'retry_count', 0)
            max_retries = getattr(state, 'max_retries', 3) if hasattr(state, 'max_retries') else state.get(
                'max_retries', 3)
            error_messages = getattr(state, 'error_messages', []) if hasattr(state, 'error_messages') else state.get(
                'error_messages', [])
            return retry_count < max_retries and len(error_messages) > 0

    @staticmethod
    def clear_errors(state: StateType) -> dict[str, Any]:
        return {
            "error_messages": [],
            "retry_count": 0,
        }


def handle_node_errors(node_name: str, custom_message: Optional[str] = None):
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def async_wrapper(self, state: StateType) -> dict[str, Any]:
            try:
                result = await func(self, state)
                if result is None:
                    result = {}
                result.update(ErrorHandler.clear_errors(state))
                return result
            except Exception as e:
                logger.exception(f"Error in async node '{node_name}'")
                return ErrorHandler.handle_error(state, e, node_name, custom_message)

        @wraps(func)
        def sync_wrapper(self, state: StateType) -> dict[str, Any]:
            try:
                result = func(self, state)
                if result is None:
                    result = {}
                result.update(ErrorHandler.clear_errors(state))
                return result
            except Exception as e:
                logger.exception(f"Error in sync node '{node_name}'")
                return ErrorHandler.handle_error(state, e, node_name, custom_message)

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    return decorator


def create_universal_router(next_node: str, end_node: str = END, node_name: Optional[str] = None):
    def router(state) -> str:
        if isinstance(state, dict):
            error_messages = state.get('error_messages', [])
            max_retries = state.get('max_retries', 3)
            failed_nodes = state.get('failed_nodes', {})
        else:
            error_messages = state.error_messages
            max_retries = state.max_retries
            failed_nodes = state.failed_nodes

        if len(error_messages) > 0 and node_name:
            has_node_error = any(node_name in msg for msg in error_messages)

            if has_node_error:
                node_retry_count = failed_nodes.get(node_name, 0)
                if node_retry_count < max_retries:
                    logger.info(f"Retrying {node_name}, attempt {node_retry_count}/{max_retries}")
                    return node_name
                else:
                    logger.error(f"Max retries reached for {node_name}, ending execution")
                    return end_node

        return next_node

    return router
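
As a quick illustration of how the decorator and router cooperate, here is a minimal, self-contained sketch (DemoState and the deliberately flaky node are invented for demonstration): the decorator converts an exception into error state, and the router re-enters the failed node until failed_nodes[node] reaches max_retries or the node succeeds.

    import asyncio

    from langgraph.graph import StateGraph, END
    from pydantic import BaseModel, Field

    from agentic_app.shared.error_handler import handle_node_errors, create_universal_router


    class DemoState(BaseModel):
        attempts: int = 0
        error_messages: list[str] = Field(default_factory=list)
        retry_count: int = 0
        max_retries: int = 3
        failed_nodes: dict[str, int] = Field(default_factory=dict)


    class FlakyNodes:
        @handle_node_errors("flaky", "Simulated transient failure")
        def flaky(self, state: DemoState) -> dict:
            # Fail on the first two attempts, succeed on the third
            if state.failed_nodes.get("flaky", 0) < 2:
                raise ConnectionError("transient outage")
            return {"attempts": state.failed_nodes.get("flaky", 0) + 1}


    async def main():
        workflow = StateGraph(DemoState)
        workflow.add_node("flaky", FlakyNodes().flaky)
        workflow.set_entry_point("flaky")
        workflow.add_conditional_edges(
            "flaky", create_universal_router(next_node=END, node_name="flaky")
        )
        graph = workflow.compile()

        final = await graph.ainvoke(DemoState())
        print(final["failed_nodes"])    # {'flaky': 2}
        print(final["error_messages"])  # [] -- cleared by the successful attempt


    asyncio.run(main())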
run_services.sh
# Run each service from the project root so the agentic_app.* imports resolve

# Terminal 1 - Research Service
uvicorn agentic_app.research_service.main:app --reload --port 8081

# Terminal 2 - Analysis Service
uvicorn agentic_app.analysis_service.main:app --reload --port 8082

# Terminal 3 - Supervisor Service
uvicorn agentic_app.supervisor_service.main:app --reload --port 8080
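
test_system.py is not reproduced on this page; a minimal smoke test in its spirit might look like the following sketch (the /process route and payload shape are assumptions based on the service code above):

    import asyncio

    import httpx

    SERVICES = {
        "research": "http://localhost:8081",
        "analysis": "http://localhost:8082",
        "supervisor": "http://localhost:8080",
    }


    async def main():
        async with httpx.AsyncClient(timeout=300.0) as client:
            # 1. Every service should report healthy
            for name, url in SERVICES.items():
                r = await client.get(f"{url}/health")
                r.raise_for_status()
                print(name, r.json())

            # 2. Drive the full pipeline through the supervisor
            r = await client.post(
                f"{SERVICES['supervisor']}/process",
                json={"query": "Impact of AI on healthcare"},
            )
            r.raise_for_status()
            result = r.json()
            print("success:", result["success"])
            print(result["final_report"][:300])


    asyncio.run(main())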

Code Explanation#

Full Code (older single-file version)
Graph
  graph TD
    Start([Start]) --> ValidateQuery[Validate Query]

    ValidateQuery -->|Success| CreatePlan[Create Research Plan]
    ValidateQuery -->|Error & Retries| ValidateQuery
    ValidateQuery -->|Error & Max Retries| End([End])

    CreatePlan -->|Success| GatherInfo[Gather Information]
    CreatePlan -->|Error & Retries| CreatePlan
    CreatePlan -->|Error & Max Retries| End

    GatherInfo -->|Success| Synthesize[Synthesize Findings]
    GatherInfo -->|Error & Retries| GatherInfo
    GatherInfo -->|Error & Max Retries| End

    Synthesize -->|Success| GenerateReport[Generate Report]
    Synthesize -->|Error & Retries| Synthesize
    Synthesize -->|Error & Max Retries| End

    GenerateReport -->|Success| End
    GenerateReport -->|Error & Retries| GenerateReport
    GenerateReport -->|Error & Max Retries| End

    style ValidateQuery fill:#e1f5ff
    style CreatePlan fill:#e1f5ff
    style GatherInfo fill:#e1f5ff
    style Synthesize fill:#e1f5ff
    style GenerateReport fill:#e1f5ff
    style Start fill:#d4edda
    style End fill:#f8d7da
import asyncio
import logging
from functools import wraps
from typing import Protocol, Optional, TypeVar, Any, Callable
from langgraph.graph import StateGraph, END
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph.state import CompiledStateGraph
from pydantic import BaseModel, Field

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ============================================================================
# ERROR HANDLER
# ============================================================================

class ErrorState(Protocol):
    error_messages: list[str]
    retry_count: int
    max_retries: int
    last_failed_node: Optional[str]
    current_node: Optional[str]


StateType = TypeVar('StateType', bound=ErrorState)


class ErrorHandler:
    @staticmethod
    def handle_error(state: StateType, error: Exception, node_name: str, custom_message: Optional[str] = None) -> dict:
        error_msg = custom_message or f"Error in {node_name}: {str(error)}"
        logger.error(f"Node '{node_name}' failed: {str(error)}")

        return {
            "error_messages": [error_msg],
            "retry_count": state.retry_count + 1,
            "last_failed_node": node_name,
            "current_node": node_name
        }

    @staticmethod
    def should_retry(state: StateType) -> bool:
        return state.retry_count < state.max_retries and len(state.error_messages) > 0

    @staticmethod
    def clear_errors(state: StateType) -> dict[str, Any]:
        return {
            "error_messages": [],
            "retry_count": 0,
            "last_failed_node": None,
        }

    @staticmethod
    def get_error_summary(state: StateType) -> dict[str, Any]:
        return {
            "has_errors": len(state.error_messages) > 0,
            "error_count": len(state.error_messages),
            "retry_count": state.retry_count,
            "last_failed_node": state.last_failed_node,
            "can_retry": ErrorHandler.should_retry(state)
        }


def handle_node_errors(node_name: str, custom_message: Optional[str] = None):
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def async_wrapper(self, state: StateType) -> dict[str, Any]:
            try:
                result = await func(self, state)
                if result is None:
                    result = {}
                result.update(ErrorHandler.clear_errors(state))
                return result
            except Exception as e:
                logger.exception(f"Error in async node '{node_name}'")
                return ErrorHandler.handle_error(state, e, node_name, custom_message)

        @wraps(func)
        def sync_wrapper(self, state: StateType) -> dict[str, Any]:
            try:
                result = func(self, state)
                if result is None:
                    result = {}
                result.update(ErrorHandler.clear_errors(state))
                return result
            except Exception as e:
                logger.exception(f"Error in sync node '{node_name}'")
                return ErrorHandler.handle_error(state, e, node_name, custom_message)

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    return decorator


# ============================================================================
# STATE DEFINITION
# ============================================================================

class ResearchState(BaseModel):
    """State for the research assistant workflow"""
    query: str = ""
    research_plan: str = ""
    search_results: list[str] = Field(default_factory=list)
    summary: str = ""
    final_report: str = ""

    # Error handling fields
    error_messages: list[str] = Field(default_factory=list)
    retry_count: int = 0
    max_retries: int = 3
    last_failed_node: Optional[str] = None
    current_node: Optional[str] = None

    # Control flow
    should_continue: bool = True

    class Config:
        arbitrary_types_allowed = True


# ============================================================================
# RESEARCH NODES
# ============================================================================

class ResearchNodes:
    """Collection of nodes for the research workflow"""

    def __init__(self, llm: Optional[ChatOllama] = None):
        self.llm = llm or ChatOllama(model="gpt-oss", temperature=0)

    @handle_node_errors("validate_query", "Failed to validate the research query")
    def validate_query(self, state: ResearchState) -> dict[str, Any]:
        """Validate that the query is appropriate for research"""
        logger.info(f"Validating query: {state.query}")

        if not state.query or len(state.query.strip()) < 5:
            raise ValueError("Query must be at least 5 characters long")

        # Simulate potential validation issues
        if "error" in state.query.lower():
            raise ValueError("Query contains forbidden terms")

        return {
            "current_node": "validate_query",
            "should_continue": True
        }

    @handle_node_errors("create_research_plan", "Failed to create research plan")
    async def create_research_plan(self, state: ResearchState) -> dict[str, Any]:
        """Create a research plan based on the query"""
        logger.info(f"Creating research plan for: {state.query}")

        messages = [
            SystemMessage(content="You are a research planning assistant. Create a brief 3-step research plan."),
            HumanMessage(content=f"Create a research plan for: {state.query}")
        ]

        response = await self.llm.ainvoke(messages)

        if not response.content:
            raise ValueError("LLM returned empty research plan")

        return {
            "research_plan": response.content,
            "current_node": "create_research_plan",
            "should_continue": True
        }

    @handle_node_errors("gather_information", "Failed to gather information")
    async def gather_information(self, state: ResearchState) -> dict[str, Any]:
        """Simulate gathering information from various sources"""
        logger.info("Gathering information...")

        # Simulate API calls that might fail
        await asyncio.sleep(0.5)

        # Simulate random failures for demonstration
        import random
        if random.random() < 0.2:  # 20% chance of failure
            raise ConnectionError("Failed to connect to research database")

        # Simulate search results
        search_results = [
            f"Research finding 1 about {state.query}",
            f"Research finding 2 about {state.query}",
            f"Research finding 3 about {state.query}",
        ]

        return {
            "search_results": search_results,
            "current_node": "gather_information",
            "should_continue": True
        }

    @handle_node_errors("synthesize_findings", "Failed to synthesize findings")
    async def synthesize_findings(self, state: ResearchState) -> dict[str, Any]:
        """Synthesize the gathered information into a summary"""
        logger.info("Synthesizing findings...")

        if not state.search_results:
            raise ValueError("No search results available to synthesize")

        findings_text = "\n".join(f"- {result}" for result in state.search_results)

        messages = [
            SystemMessage(content="You are a research synthesis assistant. Summarize the findings concisely."),
            HumanMessage(
                content=f"Research Plan:\n{state.research_plan}\n\nFindings:\n{findings_text}\n\nProvide a brief summary.")
        ]

        response = await self.llm.ainvoke(messages)

        return {
            "summary": response.content,
            "current_node": "synthesize_findings",
            "should_continue": True
        }

    @handle_node_errors("generate_report", "Failed to generate final report")
    async def generate_report(self, state: ResearchState) -> dict[str, Any]:
        """Generate the final research report"""
        logger.info("Generating final report...")

        messages = [
            SystemMessage(content="You are a report writing assistant. Create a concise final report."),
            HumanMessage(content=f"Query: {state.query}\n\nSummary: {state.summary}\n\nCreate a final report.")
        ]

        response = await self.llm.ainvoke(messages)

        return {
            "final_report": response.content,
            "current_node": "generate_report",
            "should_continue": False
        }


# ============================================================================
# ROUTING LOGIC
# ============================================================================


def create_universal_router(next_node: str, end_node: str = END):
    """Create a universal router that handles errors and retries"""

    def router(state) -> str:
        # Handle both dict and Pydantic model
        if isinstance(state, dict):
            error_messages = state.get('error_messages', [])
            retry_count = state.get('retry_count', 0)
            max_retries = state.get('max_retries', 3)
            last_failed_node = state.get('last_failed_node', 'validate_query')
        else:
            error_messages = state.error_messages
            retry_count = state.retry_count
            max_retries = state.max_retries
            last_failed_node = state.last_failed_node or 'validate_query'

        if len(error_messages) > 0:
            if retry_count < max_retries:
                logger.info(f"Retrying {last_failed_node}, attempt {retry_count}/{max_retries}")
                return last_failed_node
            else:
                logger.error(f"Max retries reached for {last_failed_node}, ending execution")
                return end_node
        else:
            return next_node

    return router


def should_retry_node(state) -> str:
    """Route back to the failed node for retry"""
    if isinstance(state, dict):
        last_failed = state.get('last_failed_node', 'validate_query')
    else:
        last_failed = state.last_failed_node or 'validate_query'

    logger.info(f"Routing to retry node: {last_failed}")
    return last_failed





# ============================================================================
# GRAPH CONSTRUCTION
# ============================================================================

def create_research_graph(llm: Optional[ChatOllama] = None) -> CompiledStateGraph:
    """Create the research assistant graph with error handling"""

    nodes = ResearchNodes(llm)

    # Create the graph
    workflow = StateGraph(ResearchState)

    # Add nodes
    workflow.add_node("validate_query", nodes.validate_query)
    workflow.add_node("create_research_plan", nodes.create_research_plan)
    workflow.add_node("gather_information", nodes.gather_information)
    workflow.add_node("synthesize_findings", nodes.synthesize_findings)
    workflow.add_node("generate_report", nodes.generate_report)

    # Set entry point
    workflow.set_entry_point("validate_query")

    # Add conditional edges using universal router
    # Router will automatically retry the failed node or move to next node
    workflow.add_conditional_edges(
        "validate_query",
        create_universal_router(next_node="create_research_plan")
    )
    workflow.add_conditional_edges(
        "create_research_plan",
        create_universal_router(next_node="gather_information")
    )
    workflow.add_conditional_edges(
        "gather_information",
        create_universal_router(next_node="synthesize_findings")
    )
    workflow.add_conditional_edges(
        "synthesize_findings",
        create_universal_router(next_node="generate_report")
    )
    workflow.add_conditional_edges(
        "generate_report",
        create_universal_router(next_node=END)
    )

    return workflow.compile()
# ============================================================================
# MAIN EXECUTION
# ============================================================================

async def main():
    """Run the research assistant"""

    print("=" * 80)
    print("RESEARCH ASSISTANT WITH ERROR HANDLING")
    print("=" * 80)

    # Create the graph
    graph = create_research_graph()

    # Test queries
    queries = [
        "What are the latest developments in quantum computing?",
        "err",  # This will fail validation (too short)
        "Impact of artificial intelligence on healthcare",
    ]

    for i, query in enumerate(queries, 1):
        print(f"\n{'=' * 80}")
        print(f"QUERY {i}: {query}")
        print(f"{'=' * 80}\n")

        initial_state = ResearchState(query=query)

        try:

            final_state = await graph.ainvoke(initial_state)

            # Display results
            print("\n" + "=" * 80)
            print("RESULTS")
            print("=" * 80)

            # final_state is a dict, not a ResearchState object
            error_messages = final_state.get("error_messages", [])

            if error_messages:
                print(f"\n❌ FAILED with errors:")
                for error in error_messages:
                    print(f"  - {error}")
                print(f"\nRetry count: {final_state.get('retry_count', 0)}/{final_state.get('max_retries', 3)}")
            else:
                print(f"\n✅ SUCCESS!")
                research_plan = final_state.get('research_plan', 'N/A')
                summary = final_state.get('summary', 'N/A')
                final_report = final_state.get('final_report', 'N/A')

                print(f"\nResearch Plan:\n{research_plan[:200] if research_plan != 'N/A' else research_plan}...")
                print(f"\nSummary:\n{summary[:200] if summary != 'N/A' else summary}...")
                print(f"\nFinal Report:\n{final_report[:300] if final_report != 'N/A' else final_report}...")

            # Show error summary (convert dict to object-like for ErrorHandler)
            print(f"\nError Summary:")
            print(f"  - Has errors: {len(error_messages) > 0}")
            print(f"  - Error count: {len(error_messages)}")
            print(f"  - Retry count: {final_state.get('retry_count', 0)}")
            print(f"  - Last failed node: {final_state.get('last_failed_node', 'None')}")



        except Exception as e:
            print(f"\n❌ Unexpected error: {e}")

        await asyncio.sleep(1)


if __name__ == "__main__":
    asyncio.run(main())