# Source code for reminix.adapters.langchain.agent

"""
LangChain Agent Adapter

Wraps a LangChain agent (created via create_agent) for use with the Reminix runtime.
This adapter works with the modern LangChain agent API that returns a CompiledGraph.

Compatibility:
    langchain >= 1.0.0
"""

from __future__ import annotations

from typing import Any, AsyncIterator, Protocol, runtime_checkable

from reminix.runtime import Agent


@runtime_checkable
class LangChainAgentProtocol(Protocol):
    """Structural type for agents built with LangChain's ``create_agent()``.

    Any object exposing a coroutine method ``ainvoke`` and a method ``astream``
    that returns an async iterator satisfies this protocol. Because the protocol
    is ``runtime_checkable``, ``isinstance`` only verifies that both attributes
    exist; it does not check their signatures.
    """

    async def ainvoke(self, input: dict[str, Any]) -> dict[str, Any]:
        """Run the agent to completion asynchronously and return the resulting state."""

    def astream(
        self,
        input: dict[str, Any],
        *,
        stream_mode: str = "updates",
    ) -> AsyncIterator[dict[str, Any]]:
        """Asynchronously yield streamed events from the agent (``"updates"`` mode by default)."""


def from_langchain_agent(
    langchain_agent: LangChainAgentProtocol,
    *,
    name: str,
    output_key: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> Agent:
    """
    Create a Reminix Agent from a LangChain agent.

    This adapter works with agents created using LangChain's `create_agent()` API,
    which returns a CompiledGraph-like object. It registers four handlers on the
    returned agent: invoke, streaming invoke, chat and streaming chat.

    Args:
        langchain_agent: A LangChain agent (result of create_agent()).
        name: Name for the Reminix agent.
        output_key: Optional key to extract from the final state as output.
        metadata: Optional metadata for the agent. It is merged after the built-in
            ``framework``/``adapter`` entries, so caller-supplied keys win.

    Returns:
        A Reminix Agent that wraps the LangChain agent.

    Example::

        from langchain.agents import create_agent
        from reminix.adapters.langchain import from_langchain_agent
        from reminix.runtime import serve

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[search_tool, calculator_tool],
        )

        reminix_agent = from_langchain_agent(agent, name="smart-agent")
        serve(reminix_agent)
    """
    agent = Agent(
        name,
        metadata={
            "framework": "langchain",
            "adapter": "agent",
            **(metadata or {}),
        },
    )

    # The ``ctx`` argument is part of the handler signature the runtime calls
    # with; these handlers do not use it.

    @agent.invoke  # type: ignore[arg-type]
    async def handle_invoke(input_data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Non-streaming invoke via LangChain agent."""
        agent_input = _prepare_agent_input(input_data)
        result = await langchain_agent.ainvoke(agent_input)
        return {"output": _extract_output(result, output_key)}

    @agent.invoke_stream  # type: ignore[arg-type]
    async def handle_invoke_stream(
        input_data: dict[str, Any], ctx: dict[str, Any]
    ) -> AsyncIterator[dict[str, Any]]:
        """Streaming invoke via LangChain agent; empty chunks are skipped."""
        agent_input = _prepare_agent_input(input_data)
        async for event in langchain_agent.astream(agent_input, stream_mode="updates"):
            chunk = _extract_stream_chunk(event)
            if chunk:
                yield {"chunk": chunk}

    @agent.chat  # type: ignore[arg-type]
    async def handle_chat(messages: list[dict[str, Any]], ctx: dict[str, Any]) -> dict[str, Any]:
        """Non-streaming chat via LangChain agent."""
        result = await langchain_agent.ainvoke(_prepare_chat_input(messages))
        output = _extract_output(result, output_key)
        return {"message": {"role": "assistant", "content": output}}

    @agent.chat_stream  # type: ignore[arg-type]
    async def handle_chat_stream(
        messages: list[dict[str, Any]], ctx: dict[str, Any]
    ) -> AsyncIterator[dict[str, Any]]:
        """Streaming chat via LangChain agent; empty chunks are skipped."""
        chat_input = _prepare_chat_input(messages)
        async for event in langchain_agent.astream(chat_input, stream_mode="updates"):
            chunk = _extract_stream_chunk(event)
            if chunk:
                yield {"chunk": chunk}

    return agent


def _prepare_chat_input(messages: list[dict[str, Any]]) -> dict[str, Any]:
    """Build the agent input for a chat request.

    The full message list is passed through under ``messages``. The last
    message's content (or "" when there are no messages, or when the last one
    has no ``content``) is duplicated under ``input``.
    """
    last_message = messages[-1].get("content", "") if messages else ""
    return {"messages": messages, "input": last_message}
def _prepare_agent_input(input_data: dict[str, Any]) -> dict[str, Any]:
    """
    Prepare input for the LangChain agent.

    Input that already has a ``messages`` key is returned unchanged. Otherwise a
    prompt is taken from ``prompt``, then ``input``. Failing both, a single-key
    payload uses its value as the prompt, unless that value is None. The
    prompt is wrapped as one user message: a ``HumanMessage`` when
    ``langchain_core`` is importable, else a ``{"role": "user", ...}`` dict.
    """
    # If input already has messages, use as-is
    if "messages" in input_data:
        return input_data

    # Extract prompt from common keys
    prompt = input_data.get("prompt") or input_data.get("input") or ""
    if not prompt and len(input_data) == 1:
        # Single key input - use its value as prompt. A None value means "no
        # prompt"; stringifying it would send the literal text "None".
        value = next(iter(input_data.values()))
        if value is not None:
            prompt = str(value)

    text = str(prompt)
    try:
        from langchain_core.messages import HumanMessage
    except ImportError:
        # Fallback to dict format
        return {"messages": [{"role": "user", "content": text}]}
    return {"messages": [HumanMessage(content=text)]}


def _content_to_text(content: Any) -> str:
    """Flatten a message ``content`` value into plain text.

    ``content`` may be a string or a list of content blocks. Blocks can be
    plain strings or dicts such as ``{"type": "text", "text": ...}``. String
    blocks and ``"text"`` dict blocks are concatenated. Any other block, such
    as a tool-use block, is dropped. ``None`` becomes "", and any other type
    falls back to ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(str(block.get("text", "")))
        return "".join(parts)
    return str(content)


def _message_text(message: Any) -> str | None:
    """Return the text of a message object or dict, or None if it has no content."""
    if hasattr(message, "content"):
        return _content_to_text(message.content)
    if isinstance(message, dict) and "content" in message:
        return _content_to_text(message["content"])
    return None


def _extract_output(result: Any, output_key: str | None) -> str:
    """Extract output from the final agent state.

    For a dict state the lookup order is as follows. First ``output_key``, if
    given and not None. Next the first non-None of
    ``output``/``response``/``result``/``answer``. Then the text of the last
    entry in ``messages``. Anything else is stringified.
    """
    if isinstance(result, dict):
        if output_key:
            value = result.get(output_key)
            if value is not None:
                return str(value)

        # Try common output keys
        for key in ("output", "response", "result", "answer"):
            value = result.get(key)
            if value is not None:
                return str(value)

        # If messages key exists, use the last message's text
        messages = result.get("messages")
        if messages:
            text = _message_text(messages[-1])
            if text is not None:
                return text

    if isinstance(result, str):
        return result
    return str(result)


def _extract_stream_chunk(event: Any) -> str | None:
    """Extract content from a streaming event.

    Events are expected as ``{node_name: state_update}`` mappings. The first
    update yielding content wins. A string update is returned directly. For a
    dict update, the first truthy value of
    ``output``/``response``/``result``/``content`` is used, else the text of
    its last message. The result may be "" (e.g. a message holding only tool
    calls). Returns None when nothing is found.
    """
    if not isinstance(event, dict):
        return None

    # LangChain/LangGraph streams as {node_name: state_update}; names are unused.
    for update in event.values():
        if isinstance(update, str):
            return update
        if not isinstance(update, dict):
            continue

        # Check for common content keys
        for key in ("output", "response", "result", "content"):
            value = update.get(key)
            if value:
                return str(value)

        # Check for messages
        messages = update.get("messages")
        if messages:
            text = _message_text(messages[-1])
            if text is not None:
                return text
    return None