"""
LangGraph CompiledGraph Adapter
Wraps a LangGraph CompiledGraph for use with the Reminix runtime.
Compatibility:
langgraph >= 1.0.0
"""
from __future__ import annotations

from typing import Any

from ..protocols import LangGraphCompiledGraphProtocol
from reminix.runtime import Agent


def from_compiled_graph(
    graph: LangGraphCompiledGraphProtocol,
    *,
    name: str,
    output_key: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> Agent:
"""
Create a Reminix Agent from a LangGraph CompiledGraph.
Args:
graph: A LangGraph CompiledGraph instance (result of StateGraph.compile()).
name: Name for the Reminix agent.
output_key: Optional key to extract from the final state as output.
If None, returns the entire final state.
metadata: Optional metadata for the agent.
Returns:
A Reminix Agent that wraps the LangGraph graph.
Example::
from langgraph.graph import StateGraph, START, END
from typing import TypedDict
from reminix.adapters.langgraph import from_compiled_graph
from reminix.runtime import serve
class State(TypedDict):
input: str
output: str
def process(state: State) -> dict:
return {"output": f"Processed: {state['input']}"}
graph = (
StateGraph(State)
.add_node("process", process)
.add_edge(START, "process")
.add_edge("process", END)
.compile()
)
agent = from_compiled_graph(graph, name="processor", output_key="output")
serve(agent)
"""
    agent = Agent(
        name,
        metadata={
            "framework": "langgraph",
            "adapter": "compiled-graph",
            **(metadata or {}),
        },
    )

    @agent.invoke  # type: ignore[arg-type]
    async def handle_invoke(input_data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Non-streaming invoke via LangGraph."""
        graph_input = _prepare_graph_input(input_data)
        result = await graph.ainvoke(graph_input)
        return {"output": _extract_output(result, output_key)}

    @agent.invoke_stream  # type: ignore[arg-type]
    async def handle_invoke_stream(input_data: dict[str, Any], ctx: dict[str, Any]):
        """Streaming invoke via LangGraph."""
        graph_input = _prepare_graph_input(input_data)
        async for event in graph.astream(graph_input, stream_mode="updates"):
            chunk = _extract_stream_chunk(event)
            if chunk:
                yield {"chunk": chunk}

    @agent.chat  # type: ignore[arg-type]
    async def handle_chat(messages: list[dict[str, Any]], ctx: dict[str, Any]) -> dict[str, Any]:
        """Non-streaming chat via LangGraph."""
        # Pass the conversation through under "messages"; also expose the
        # latest user turn under "input" for graphs that expect a plain string.
        last_message = messages[-1]["content"] if messages else ""
        input_data = {"messages": messages, "input": last_message}
        result = await graph.ainvoke(input_data)
        output = _extract_output(result, output_key)
        return {"message": {"role": "assistant", "content": output}}

    @agent.chat_stream  # type: ignore[arg-type]
    async def handle_chat_stream(messages: list[dict[str, Any]], ctx: dict[str, Any]):
        """Streaming chat via LangGraph."""
        last_message = messages[-1]["content"] if messages else ""
        input_data = {"messages": messages, "input": last_message}
        async for event in graph.astream(input_data, stream_mode="updates"):
            chunk = _extract_stream_chunk(event)
            if chunk:
                yield {"chunk": chunk}

    return agent
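
# Illustrative sketch (not part of the adapter): for the "processor" graph in
# the docstring example above, the wrapped invoke handler reduces to roughly:
#
#     result = await graph.ainvoke({"input": "hello"})
#     _extract_output(result, "output")  # -> "Processed: hello"
#
# The exact handler signatures are whatever reminix.runtime.Agent expects.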


def _prepare_graph_input(input_data: dict[str, Any]) -> dict[str, Any]:
    """
    Prepare input for the LangGraph graph.

    Converts prompt-based input to a messages format if needed.
    If the input already has "messages", it is passed through as-is.
    """
    if "messages" in input_data:
        return input_data

    # Extract the prompt from common keys
    prompt = input_data.get("prompt") or input_data.get("input") or ""
    if not prompt and len(input_data) == 1:
        # Single-key input: use its value as the prompt
        prompt = str(next(iter(input_data.values())))

    # Prefer langchain_core message objects when the package is available
    try:
        from langchain_core.messages import HumanMessage

        return {"messages": [HumanMessage(content=str(prompt))]}
    except ImportError:
        # Fall back to plain dict messages
        return {"messages": [{"role": "user", "content": str(prompt)}]}


def _extract_output(result: Any, output_key: str | None) -> str:
    """Extract a string output from the final graph state."""
    if output_key and isinstance(result, dict):
        value = result.get(output_key)
        if value is not None:
            return str(value)

    if isinstance(result, dict):
        # Try common output keys
        for key in ["output", "response", "result", "answer"]:
            if key in result and result[key] is not None:
                return str(result[key])

        # If a messages key exists, use the last message's content
        if "messages" in result and result["messages"]:
            last_msg = result["messages"][-1]
            if hasattr(last_msg, "content"):
                return str(last_msg.content)
            if isinstance(last_msg, dict) and "content" in last_msg:
                return str(last_msg["content"])

    if isinstance(result, str):
        return result
    return str(result)
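
# Illustrative behaviour of _extract_output:
#   _extract_output({"answer": "42"}, None)                  -> "42"
#   _extract_output({"output": None, "result": 3}, "output") -> "3" (falls through)
#   _extract_output(["not", "a", "dict"], None)              -> "['not', 'a', 'dict']"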


def _extract_stream_chunk(event: Any) -> str | None:
    """Extract content from a streaming event."""
    if isinstance(event, dict):
        # LangGraph "updates" events arrive as {node_name: state_update}
        for update in event.values():
            if isinstance(update, dict):
                # Check common content keys
                for key in ["output", "response", "result", "content"]:
                    if key in update and update[key]:
                        return str(update[key])

                # Check for messages
                if "messages" in update and update["messages"]:
                    last_msg = update["messages"][-1]
                    if hasattr(last_msg, "content"):
                        return str(last_msg.content)
                    if isinstance(last_msg, dict) and "content" in last_msg:
                        return str(last_msg["content"])
            elif isinstance(update, str):
                return update
    return None
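
# Illustrative "updates" events and the chunk extracted from each:
#   {"process": {"output": "Processed: hi"}}             -> "Processed: hi"
#   {"chat": {"messages": [AIMessage(content="Hello")]}} -> "Hello"
#   {"noop": {}}                                         -> None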