"""Streamlit front end for a LangGraph/Groq legal-text rewriter.

The user submits legal text plus a tone; an LLM rewrites it. The user can then
regenerate the rewrite, refine it with free-text feedback, or approve it.
The conversation is kept in a LangGraph checkpointer stored per Streamlit
session, because Streamlit re-executes this whole script on every interaction.
"""

import os

import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from langchain_groq import ChatGroq
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

# Load API key from Hugging Face Secrets
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY not found. Please add it as a Secret in Hugging Face Spaces.")
    st.stop()

# Model can be overridden via the environment; defaults to the original model.
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama3-8b-8192")

# Initialize the Groq LLM
llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name=GROQ_MODEL)

# Text protocol shared by the UI (which builds these messages) and the assistant
# node (which parses them). Rewrite requests and rewritten replies both have the
# form "<prefix><tone> tone: <text>".
REWRITE_PREFIX = "Rewrite this text in "
REWRITTEN_PREFIX = "Rewritten text in "
APPROVED_PREFIX = "Text approved: "
FEEDBACK_PREFIX = "feedback:"
TONE_SEPARATOR = " tone: "
TONES = ["Formal", "Empathetic", "Neutral", "Strength-Based"]


# Helper functions
def _split_tone_message(message, prefix):
    """Parse ``"<prefix><tone> tone: <text>"`` into ``(tone, text)``.

    Only the first " tone: " is treated as the separator, so the text itself may
    contain that phrase without being truncated. Returns None when ``message``
    does not start with ``prefix`` or has no separator.
    """
    if not message.startswith(prefix):
        return None
    tone, separator, text = message[len(prefix):].partition(TONE_SEPARATOR)
    if not separator:
        return None
    return tone, text


def _latest_tone_message(messages, message_type, prefix):
    """Return ``(tone, text)`` from the newest ``message_type`` message starting with ``prefix``.

    Returns None when no such message exists in ``messages``.
    """
    for msg in reversed(messages):
        if isinstance(msg, message_type):
            parsed = _split_tone_message(msg.content, prefix)
            if parsed is not None:
                return parsed
    return None


def _rewritten_reply(tone, text):
    """Wrap ``text`` as a state update in the "Rewritten text in <tone> tone: ..." format."""
    return {"messages": [AIMessage(content=f"{REWRITTEN_PREFIX}{tone}{TONE_SEPARATOR}{text}")]}


def _rewrite(tone, text):
    """Ask the LLM to rewrite ``text`` in ``tone`` and wrap the answer as a reply."""
    response = llm.invoke(f"Rewrite the following legal text in a {tone} tone: {text}")
    return _rewritten_reply(tone, response.content)


def extract_display_text(message):
    """Strip the "Rewritten text in <tone> tone: " header; return any other message unchanged."""
    parsed = _split_tone_message(message, REWRITTEN_PREFIX)
    return message if parsed is None else parsed[1]


# Assistant function
def assistant(state: MessagesState):
    """Graph node: interpret the newest human message as a command and reply.

    Supported commands:
      * "Rewrite this text in <tone> tone: <text>" -- rewrite ``text``.
      * "regenerate" -- rewrite the most recent rewrite request again.
      * "feedback: <text>" -- refine the most recent rewrite using the feedback.
      * "approve" (any case) -- approve the most recent rewrite.
    Any other message, or a command with nothing to act on, gets "Invalid command.".
    """
    messages = state["messages"]
    latest_msg = messages[-1].content

    request = _split_tone_message(latest_msg, REWRITE_PREFIX)
    if request is not None:
        tone, text = request
        return _rewrite(tone, text)

    if latest_msg == "regenerate":
        # Newest request, not the first one ever submitted in this thread.
        previous = _latest_tone_message(messages, HumanMessage, REWRITE_PREFIX)
        if previous is not None:
            tone, original_text = previous
            return _rewrite(tone, original_text)
    elif latest_msg.startswith(FEEDBACK_PREFIX):
        feedback = latest_msg[len(FEEDBACK_PREFIX):].strip()
        previous = _latest_tone_message(messages, AIMessage, REWRITTEN_PREFIX)
        if previous is not None:
            tone, text = previous
            response = llm.invoke(
                f"Refine the following text based on this feedback: '{feedback}'. "
                f"Maintain the {tone} tone.\nText: {text}"
            )
            return _rewritten_reply(tone, response.content)
    elif latest_msg.lower() == "approve":
        previous = _latest_tone_message(messages, AIMessage, REWRITTEN_PREFIX)
        if previous is not None:
            return {"messages": [AIMessage(content=f"{APPROVED_PREFIX}{previous[1]}")]}

    return {"messages": [AIMessage(content="Invalid command.")]}


# Human feedback node (stop condition)
def human_feedback(state: MessagesState):
    """Placeholder node; the graph is compiled to pause before it so the user can act."""
    return state


def build_graph():
    """Compile the assistant <-> human_feedback loop with an in-memory checkpointer."""
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("human_feedback", human_feedback)
    builder.add_edge(START, "assistant")
    builder.add_edge("assistant", "human_feedback")  # paused here by interrupt_before
    builder.add_edge("human_feedback", "assistant")  # resumed by the next user command
    return builder.compile(interrupt_before=["human_feedback"], checkpointer=MemorySaver())


# Streamlit re-runs this script on every interaction; keep the graph (and thus its
# MemorySaver) in session state so the conversation survives across reruns and is
# not shared between browser sessions.
if "graph" not in st.session_state:
    st.session_state.graph = build_graph()
if "state" not in st.session_state:
    st.session_state.state = {"configurable": {"thread_id": "1"}}


def send_command(command):
    """Send ``command`` to the graph and store the displayable reply in session state."""
    graph = st.session_state.graph
    config = st.session_state.state
    message = HumanMessage(content=command)

    if graph.get_state(config).values.get("messages"):
        # Record the command as if the paused human_feedback node produced it, so
        # execution resumes at "assistant" instead of pausing before
        # human_feedback again.
        graph.update_state(config, {"messages": [message]}, as_node="human_feedback")
        stream_input = None
    else:
        # Empty thread: start a normal run from START with the command as input.
        stream_input = {"messages": [message]}

    for event in graph.stream(stream_input, config, stream_mode="values"):
        if "messages" in event:
            st.session_state.rewritten_text = extract_display_text(event["messages"][-1].content)


# Streamlit UI
st.title("📜 Legal Text Rewriter")
st.markdown("Rewrite legal text into different tones using AI.")

# User input section
legal_text = st.text_area("Enter Legal Text:", "")
initial_tone = st.selectbox("Select Tone:", TONES)
submit_btn = st.button("Submit")

if submit_btn and legal_text.strip():
    send_command(f"{REWRITE_PREFIX}{initial_tone}{TONE_SEPARATOR}{legal_text}")

# Display rewritten text
if "rewritten_text" in st.session_state:
    st.subheader("Rewritten Text")
    # Filled after the buttons below so a click's result shows in this same run
    # rather than only after the next rerun.
    output_slot = st.empty()

    # Buttons
    col1, col2, col3 = st.columns([1, 2, 1])
    with col1:
        if st.button("🔄 Regenerate"):
            send_command("regenerate")
    with col2:
        feedback_text = st.text_input("Enter Feedback:")
        if st.button("💬 Submit Feedback"):
            if feedback_text.strip():
                send_command(f"feedback: {feedback_text}")
            else:
                st.warning("Please enter feedback before submitting.")
    with col3:
        if st.button("✅ Approve"):
            send_command("approve")

    output_slot.write(st.session_state.rewritten_text)