While LangChain offers a versatile suite of built-in memory modules, production applications often encounter unique contextual requirements that standard implementations cannot fully address. You might need to integrate with proprietary databases, enforce domain-specific context filtering, implement novel summarization techniques, or manage conversational state in ways tailored to your specific use case. This is where developing custom memory modules becomes necessary.
Building a custom memory module allows you to define precisely how conversational history is stored, retrieved, and manipulated, providing fine-grained control over the context provided to your LLMs, chains, and agents.
At its core, any LangChain memory module must adhere to a specific interface. The primary base class you'll typically interact with is `BaseChatMemory`, which itself inherits from `BaseMemory`. To create a functional custom memory module, you need to implement a few key methods:

- `__init__(self, ...)`: Your constructor. This is where you'll initialize any necessary resources, such as database connections, configuration parameters (like history length limits), or internal state variables. You'll also typically call the parent class constructor (`super().__init__(...)`).
- `memory_variables(self) -> List[str]`: This property should return a list of string keys that your memory module expects to inject into the prompt. The most common key is `"history"`, but you might define others depending on your logic.
- `load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]`: This is the crucial method for retrieving context. It receives a dictionary of the current inputs to the chain or agent (excluding memory variables). Your implementation should fetch the relevant conversational history or state based on these inputs and return a dictionary whose keys match those defined in `memory_variables` and whose values contain the formatted context (e.g., a string of past messages).
- `save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None`: This method is called after the chain or agent execution. It receives the original inputs (excluding memory variables) and the final outputs generated by the LLM. Your task is to process this new interaction turn and store it appropriately according to your custom logic (e.g., append to an internal list, save to a database, update entity information).
- `clear(self) -> None`: This method should reset the memory's state, effectively clearing the conversation history.

Here's a conceptual diagram illustrating the interaction flow:
Interaction flow between a LangChain execution unit and a custom memory module's core methods.
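To make the interface concrete, here is a minimal skeleton of these methods. It is a sketch only: the class name, the in-memory list, and the hardcoded `'input'`/`'output'` keys are illustrative placeholders, not a prescribed implementation.

```python
from typing import Any, Dict, List

from langchain.memory.chat_memory import BaseChatMemory


class SkeletonMemory(BaseChatMemory):
    """Illustrative skeleton: stores turns in a plain list under one key."""

    # BaseChatMemory is a pydantic model, so state is declared as fields.
    turns: List[str] = []
    memory_key: str = "history"

    @property
    def memory_variables(self) -> List[str]:
        # Keys this module injects into the prompt.
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Called before the LLM runs: return the stored context.
        return {self.memory_key: "\n".join(self.turns)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        # Called after the LLM runs: record the new turn.
        self.turns.append(f"Human: {inputs.get('input', '')}")
        self.turns.append(f"AI: {outputs.get('output', '')}")

    def clear(self) -> None:
        # Reset all stored state.
        self.turns = []
```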
The power of custom memory lies in the logic you implement within `load_memory_variables` and `save_context`. Consider these possibilities:

- Summarization: condense older parts of the conversation, either when context is assembled in `load_memory_variables` or periodically within `save_context`.
- Dynamic windowing: instead of a fixed `k` turns, dynamically determine how much history to load based on the current input's complexity or estimated token count (a sketch follows this list).
- Relevance filtering: return only the turns related to the current input during the `load_memory_variables` step.
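As one hedged illustration of dynamic windowing, the helper below selects as many recent messages as fit a token budget. The function name, the budget value, and the crude four-characters-per-token estimate are assumptions for illustration; a production version would use a real tokenizer such as tiktoken.

```python
from langchain_core.messages import BaseMessage


def select_recent_within_budget(
    messages: list[BaseMessage], token_budget: int = 1000
) -> list[BaseMessage]:
    """Walk backwards from the newest message, keeping as many as fit.

    Uses a rough len(text) / 4 token estimate; swap in a real tokenizer
    for production use.
    """
    selected: list[BaseMessage] = []
    used = 0
    for message in reversed(messages):
        cost = max(1, len(str(message.content)) // 4)  # crude token estimate
        if used + cost > token_budget:
            break
        selected.append(message)
        used += cost
    selected.reverse()  # restore chronological order
    return selected
```

A module's `load_memory_variables` could call such a helper on `self.chat_memory.messages` before formatting the result.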
As a concrete example, let's sketch `RelevantTurnMemory`, a simple custom memory module that only stores and retrieves turns containing specific keywords. This is a basic example, but it illustrates the core implementation pattern.
```python
import re
from typing import Any, Dict, List

from langchain.memory.chat_memory import BaseChatMemory
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    get_buffer_string,
)

# Example keyword list; in practice you pass your own when constructing the memory.
relevant_keywords = ["project alpha", "budget", "milestone"]


class RelevantTurnMemory(BaseChatMemory):
    """
    Memory that only stores and retrieves conversational turns
    containing specific keywords in the input or output.
    """

    # BaseChatMemory is a pydantic model, so all state is declared as fields.
    relevant_turns: List[BaseMessage] = []
    keywords_pattern: Any = None          # Compiled regex, built in __init__
    memory_key: str = "relevant_history"  # Custom memory key
    input_key: str = "input"              # Specify if your chain's input key differs
    output_key: str = "output"            # Specify if your chain's output key differs
    return_messages: bool = False         # Output format: string vs. List[BaseMessage]
    human_prefix: str = "Human"           # Prefixes used when formatting as a string
    ai_prefix: str = "AI"

    def __init__(self, keywords: List[str], return_messages: bool = False,
                 memory_key: str = "relevant_history",
                 input_key: str = "input", output_key: str = "output"):
        # Pass declared field values to the pydantic parent constructor
        # instead of assigning undeclared attributes afterwards.
        super().__init__(
            return_messages=return_messages,
            memory_key=memory_key,
            input_key=input_key,
            output_key=output_key,
        )
        self.keywords_pattern = re.compile(
            "|".join(map(re.escape, keywords)), re.IGNORECASE
        )
        self.relevant_turns = []

    @property
    def memory_variables(self) -> List[str]:
        """The list of keys returned by load_memory_variables."""
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Retrieve the relevant conversation history."""
        if self.return_messages:
            # Return BaseMessage objects
            context: Any = self.relevant_turns
        else:
            # Return a formatted string
            context = get_buffer_string(
                self.relevant_turns,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )
        return {self.memory_key: context}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context to memory only if it is relevant."""
        input_str = inputs.get(self.input_key, "")
        output_str = outputs.get(self.output_key, "")
        is_relevant = bool(self.keywords_pattern.search(input_str)) or \
                      bool(self.keywords_pattern.search(output_str))
        if is_relevant:
            # Wrap the raw strings in message objects. This assumes inputs and
            # outputs each carry a single value under self.input_key /
            # self.output_key; adjust if your structure differs.
            input_message = HumanMessage(content=input_str)
            output_message = AIMessage(content=output_str)
            self.chat_memory.add_messages([input_message, output_message])
            # Mirror the turn in the list used by load_memory_variables.
            self.relevant_turns.extend([input_message, output_message])

    def clear(self) -> None:
        """Clear the relevant memory."""
        self.relevant_turns = []
        self.chat_memory.clear()  # Also clear the underlying ChatMessageHistory


# --- How to use it ---
# from langchain_openai import ChatOpenAI
# from langchain.chains import ConversationChain

# llm = ChatOpenAI(model="gpt-4", temperature=0)

# ConversationChain's default prompt expects a "history" variable and returns
# its answer under the "response" key, so align the memory keys accordingly.
# custom_memory = RelevantTurnMemory(
#     keywords=["billing", "invoice", "payment"],
#     memory_key="history",
#     output_key="response",
# )

# conversation = ConversationChain(
#     llm=llm,
#     memory=custom_memory,
#     verbose=True
# )

# # Example interaction
# print(conversation.predict(input="Hello, how are you?"))
# # -> This turn is not saved, since no keyword is present.
# print(conversation.predict(input="I have a question about my recent invoice."))
# # -> This turn IS saved, due to the keyword "invoice".
# print(conversation.predict(input="What was the total amount?"))
# # -> When loading memory for this turn, the previous relevant turn is included.
```
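Before wiring the class into a chain, you can exercise it directly. This quick sanity check is a sketch using only the methods defined above and the default `relevant_history` memory key:

```python
memory = RelevantTurnMemory(keywords=["invoice"])

# An irrelevant turn is ignored...
memory.save_context({"input": "Hello!"}, {"output": "Hi, how can I help?"})
assert memory.load_memory_variables({}) == {"relevant_history": ""}

# ...while a turn mentioning a keyword is kept.
memory.save_context(
    {"input": "My invoice looks wrong."},
    {"output": "Let me check that invoice for you."},
)
print(memory.load_memory_variables({})["relevant_history"])
# Human: My invoice looks wrong.
# AI: Let me check that invoice for you.

memory.clear()
assert memory.relevant_turns == []
```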
Using your custom memory module is straightforward. Instantiate your custom class and pass the instance to the `memory` parameter of your LangChain `Chain` or `AgentExecutor`:
```python
# Assuming a hypothetical MyCustomDatabaseMemory class (sketched below)
# that opens its database connection in __init__

# db_connection_string = "your_db_connection_details"
# custom_db_memory = MyCustomDatabaseMemory(connection_string=db_connection_string)

# agent_executor = AgentExecutor.from_agent_and_tools(
#     agent=my_agent,
#     tools=my_tools,
#     memory=custom_db_memory,  # Pass the custom memory instance
#     verbose=True
# )

# # Now agent_executor will use your custom logic for memory
# agent_executor.run("User query...")
```
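Since `MyCustomDatabaseMemory` is only referenced above, here is one possible sketch of such a persistent memory, backed by SQLite. The class name, the `connection_string` parameter, and the table schema are assumptions for illustration, not an existing LangChain component:

```python
import sqlite3
from typing import Any, Dict, List

from langchain.memory.chat_memory import BaseChatMemory


class MyCustomDatabaseMemory(BaseChatMemory):
    """Illustrative sketch: persists turns to a SQLite table."""

    connection_string: str = "memory.db"  # Path to the SQLite file (assumed)
    memory_key: str = "history"
    conn: Any = None  # sqlite3.Connection, opened in __init__

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        # Open the connection and create the table once, up front.
        self.conn = sqlite3.connect(self.connection_string)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS turns "
            "(id INTEGER PRIMARY KEY, role TEXT, content TEXT)"
        )
        self.conn.commit()

    @property
    def memory_variables(self) -> List[str]:
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        rows = self.conn.execute(
            "SELECT role, content FROM turns ORDER BY id"
        ).fetchall()
        history = "\n".join(f"{role}: {content}" for role, content in rows)
        return {self.memory_key: history}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        # Production code should add error handling around these I/O calls.
        self.conn.execute(
            "INSERT INTO turns (role, content) VALUES (?, ?)",
            ("Human", inputs.get("input", "")),
        )
        self.conn.execute(
            "INSERT INTO turns (role, content) VALUES (?, ?)",
            ("AI", outputs.get("output", "")),
        )
        self.conn.commit()

    def clear(self) -> None:
        self.conn.execute("DELETE FROM turns")
        self.conn.commit()
```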
Keep these practical considerations in mind:

- Persistence: use `save_context` and `__init__` (or a dedicated loading method) to interact with databases, files, or other persistent stores. Ensure robust error handling for I/O operations.
- Async support: if your chains or agents run asynchronously (`arun`, `acall`), your custom memory should ideally implement the asynchronous counterparts: `aload_memory_variables` and `asave_context`. These methods should perform I/O operations using `async`/`await` compatible libraries (e.g., `aiohttp`, `asyncpg`). A minimal sketch follows this list.
- Serialization: make sure any internal state (`self.relevant_turns` in the example) is serializable (e.g., using Python's `pickle` or converting to/from JSON). `BaseMessage` objects provided by LangChain are generally serializable.
- Testing: verify that `load_memory_variables` returns the expected context under various conditions, `save_context` stores data correctly, and `clear` functions as intended. Test edge cases like empty history or maximum history limits.
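As one hedged illustration of the async counterparts, the sketch below subclasses the `RelevantTurnMemory` example. Because its state lives in process memory, the async methods can simply wrap the synchronous logic; a database-backed module would `await` real async I/O here instead:

```python
from typing import Dict, Any


class AsyncRelevantTurnMemory(RelevantTurnMemory):
    """Adds async counterparts; no real I/O is awaited in this sketch."""

    async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # A database-backed module would `await` an async driver call
        # here (e.g., asyncpg, aiohttp) instead of delegating.
        return self.load_memory_variables(inputs)

    async def asave_context(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> None:
        # Likewise, persist the turn with awaited async I/O in production.
        self.save_context(inputs, outputs)
```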
Developing custom memory modules unlocks a significant level of control and customization for your LangChain applications. By understanding the core interface and carefully designing your storage and retrieval logic, you can build sophisticated conversational systems capable of managing context in ways perfectly aligned with your application's unique requirements.