-
Notifications
You must be signed in to change notification settings - Fork 7k
Seeking Advice on Optimizing SageMaker/Hugging Face Endpoint for Cypher Query Generation #4723
-
Hello everyone,
I am currently working on a system that generates Cypher queries based on a user's question and then answers the question based on the query results. I am using two separate Llama-3 8B endpoints on AWS SageMaker to accomplish this. As a newcomer to this field, I would greatly appreciate any advice on optimizing my setup. Here are a few areas where I need guidance:
Improving Efficiency: Are there better ways to achieve this functionality rather than using two separate endpoints?
Message Formatting: I am struggling with formatting my messages, resulting in a significant amount of post-processing. Any tips on how to streamline this process would be highly beneficial.
Model Training: I am training the model on specific data (attached below). I would appreciate any suggestions on improving the training process or the data itself.
Here is my code:
`
from sagemaker.predictor import Predictor from datetime import datetime from typing import List, Dict from huggingface_hub import login # Neo4j connection details uri = uri user = "neo4j" password = password # Create a Neo4j driver instance driver = GraphDatabase.driver(uri, auth=(user, password)) def infer_type(value): if isinstance(value, int): return "Integer" elif isinstance(value, float): return "Float" elif isinstance(value, bool): return "Boolean" elif isinstance(value, str): return "String" elif isinstance(value, dict): return "Map" elif isinstance(value, list): return "List" else: return "Unknown" def get_node_properties_by_label(label, driver): query = f""" MATCH (n:{label}) RETURN n """ with driver.session() as session: result = session.run(query) properties = {} for record in result: node = record["n"] for key, value in dict(node).items(): if key not in properties: properties[key] = infer_type(value) return properties def get_all_labels(driver): query = "CALL db.labels() YIELD label RETURN label" with driver.session() as session: result = session.run(query) labels = [record["label"] for record in result] return labels def get_relationship_properties(driver, relationship_type): query = f""" MATCH ()-[r:{relationship_type}]->() RETURN r LIMIT 1 """ with driver.session() as session: result = session.run(query) properties = {} for record in result: relationship = record["r"] for key, value in dict(relationship).items(): if key not in properties: properties[key] = infer_type(value) return properties def get_relationships_with_properties(driver): query = """ MATCH (start)-[r]->(end) RETURN DISTINCT labels(start)[0] AS start_label, type(r) AS relationship_type, labels(end)[0] AS end_label """ with driver.session() as session: result = session.run(query) relationships = [] for record in result: relationship_type = record["relationship_type"] relationship_properties = get_relationship_properties(driver, relationship_type) relationships.append({ "start_node": 
record["start_label"], "relationship_type": relationship_type, "end_node": record["end_label"], "properties": relationship_properties }) return relationships def get_neo4j_schema_with_properties_and_relationships(driver): labels = get_all_labels(driver) schema = {"nodes": {}, "relationships": []} for label in labels: properties = get_node_properties_by_label(label, driver) schema["nodes"][label] = [properties] schema["relationships"] = get_relationships_with_properties(driver) return schema def initialize_aws_session(): sts_client = boto3.client('sts') assumed_role = sts_client.assume_role( {MY_ROLE} ) credentials = assumed_role['Credentials'] assumed_session = boto3.Session( aws_access_key_id=credentials['AccessKeyId'], aws_secret_access_key=credentials['SecretAccessKey'], aws_session_token=credentials['SessionToken'], ) return sagemaker.Session(boto_session=assumed_session) def upload_file_to_s3(file_path, bucket, s3_path, sagemaker_session): s3_client = sagemaker_session.boto_session.client('s3') try: s3_client.upload_file(file_path, bucket, s3_path) print(f"File {file_path} uploaded to s3://{bucket}/{s3_path}") except ClientError as e: print(f"Failed to upload {file_path} to S3: {e}") def upload_training_data(sagemaker_session): output_bucket = sagemaker_session.default_bucket() train_data_location_cypher = f"s3://{output_bucket}/cypher_train_dataset/cypherquerygen.jsonl" train_data_location_answer = f"s3://{output_bucket}/answer_train_dataset/answergen.jsonl" template_location_cypher = f"s3://{output_bucket}/cypher_train_dataset/template.json" template_location_answer = f"s3://{output_bucket}/answer_train_dataset/template.json" # Local paths to your training data files local_cypher_train_file = 'Model/cypherquerygen.jsonl' local_answer_train_file = 'Model/answergen.jsonl' local_cypher_template_file = 'Model/template.json' local_answer_template_file = 'Model/template.json' # Upload files to S3 upload_file_to_s3(local_cypher_train_file, output_bucket, 
'cypher_train_dataset/cypherquerygen.jsonl', sagemaker_session) upload_file_to_s3(local_answer_train_file, output_bucket, 'answer_train_dataset/answergen.jsonl', sagemaker_session) upload_file_to_s3(local_cypher_template_file, output_bucket, 'cypher_train_dataset/template.json', sagemaker_session) upload_file_to_s3(local_answer_template_file, output_bucket, 'answer_train_dataset/template.json', sagemaker_session) print(f"CypherQueryGen Training data: {train_data_location_cypher}") print(f"AnswerGen Training data: {train_data_location_answer}") return train_data_location_cypher, train_data_location_answer, template_location_cypher, template_location_answer def authenticate_huggingface(): HUGGINGFACE_TOKEN = "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" login(HUGGINGFACE_TOKEN) def initialize_tokenizer(): tokenizer = LlamaTokenizerFast.from_pretrained('meta-llama/Meta-Llama-3-8B') print(tokenizer.tokenize("This is a test sentence.")) return tokenizer def set_hyperparameters(instruction_tuned=True, chat_dataset=False): if instruction_tuned and chat_dataset: raise ValueError("Both instruction_tuned and chat_dataset cannot be True at the same time.") hyperparameters = { "instruction_tuned": str(instruction_tuned), "chat_dataset": str(chat_dataset), "epoch": "5", "max_input_length": "512", "preprocessing_num_workers": "1", "per_device_train_batch_size": "1", # Reduce batch size to fit in GPU memory "gradient_accumulation_steps": "16", # Accumulate gradients over 16 steps "fp16": "true" # Enable mixed precision training } return hyperparameters def create_estimator(model_id, model_version, role, environment, instance_type, train_data_location, job_name, sagemaker_session): estimator = JumpStartEstimator( model_id=model_id, model_version=model_version, role=role, environment=environment, disable_output_compression=True, instance_type=instance_type, sagemaker_session=sagemaker_session ) hyperparameters = set_hyperparameters() estimator.set_hyperparameters(**hyperparameters) # 
Start the training job asynchronously estimator.fit({"training": train_data_location}, wait=False, job_name=job_name) return estimator def deploy_model(estimator, endpoint_name, sagemaker_session): inference_image_uri = get_huggingface_llm_image_uri("huggingface", version="2.0.2") model = Model( image_uri=inference_image_uri, model_data=estimator.model_data, role="arn:aws:iam::975050073207:role/SageMaker_Capability", sagemaker_session=sagemaker_session, env=estimator.environment ) predictor = model.deploy( initial_instance_count=1, instance_type="ml.g5.24xlarge", # Use supported instance with more memory for deployment endpoint_name=endpoint_name, model_data_download_timeout=3600, container_startup_health_check_timeout=3600 ) return predictor def wait_for_training_completion(job_name, assumed_session): sm_client = assumed_session.client("sagemaker") while True: response = sm_client.describe_training_job(TrainingJobName=job_name) status = response["TrainingJobStatus"] if status in ["Completed", "Failed", "Stopped"]: print(f"Training job {job_name} status: {status}") if status != "Completed": raise Exception(f"Training job {job_name} did not complete successfully.") break time.sleep(60) def endpoint_exists(endpoint_name, assumed_session): try: sm_client = assumed_session.client("sagemaker") response = sm_client.describe_endpoint(EndpointName=endpoint_name) if response['EndpointStatus'] == 'InService': return True else: return False except sm_client.exceptions.ResourceNotFound: return False except ClientError as error: print(f"Client error while checking endpoint: {error}") return False def format_messages(messages): """ Format messages according to the Llama 3 chat template. Each message is formatted with the role, followed by the content. 
""" formatted_messages = [] for message in messages: formatted_messages.append(f"{message['role']}\n\n{message['content']}\n\n") # Add an empty assistant role for the model to generate a response formatted_messages.append("assistant\n\n") return "".join(formatted_messages) def execute_cypher_query(query): with driver.session() as session: result = session.run(query) return [record.data() for record in result] def clean_generated_query(query: str, schema: dict) -> str: """Remove formatting strings and extraneous text from the generated Cypher query.""" to_remove = ["<<SYS>>", "<</SYS>>", "<s>[INST]", "[/INST]", "[SYS]", "</s>", "SYS", "INST", " s", "Output:", "The query above"] for item in to_remove: query = query.replace(item, "") # Extract Cypher query up to the first semicolon query = query.split(';')[0].strip() # Remove any duplicate lines and unnecessary whitespace query_lines = query.split('\n') unique_lines = [] for line in query_lines: line = line.strip() if line and line not in unique_lines: unique_lines.append(line) cleaned_query = ' '.join(unique_lines).strip() # Replace node labels and relationship types with correct casing from schema node_labels = schema["nodes"].keys() relationship_types = [rel["relationship_type"] for rel in schema["relationships"]] for label in node_labels: cleaned_query = re.sub(rf'\b{label}\b', label.lower(), cleaned_query, flags=re.IGNORECASE) for rel_type in relationship_types: cleaned_query = re.sub(rf'\b{rel_type}\b', rel_type.upper(), cleaned_query, flags=re.IGNORECASE) return re.sub(r'\s+', ' ', cleaned_query) # Replace multiple whitespace with a single space def clean_final_response(response: str) -> str: """Remove formatting strings from the final response and strip whitespace.""" to_remove = ["<<SYS>>", "<</SYS>>", "<s>[INST]", "[/INST]", "[SYS]", "</s>", "SYS", "INST", " s"] for item in to_remove: response = response.replace(item, "") # Retain only specific characters and numbers response = re.sub(r'[^\w\s:.]', '', 
response) response = re.sub(r'\s+', ' ', response).strip() # Replace multiple whitespace with a single space and strip # Extract the concise answer after the keyword "Results: " start_index = response.find("Results: ") if start_index != -1: response = response[start_index + len("Results: "):] # Extract the numerical answer match = re.search(r'\d+', response) if match: return match.group(0) return response def handle_question(question, context): sagemaker_session = initialize_aws_session() schema = get_neo4j_schema_with_properties_and_relationships(driver) # Define schema here # Check if CypherQueryGen and AnswerGen endpoints exist cypher_endpoint_name = "CypherQueryGen" answer_endpoint_name = "AnswerGen" cypher_exists = endpoint_exists(cypher_endpoint_name, sagemaker_session.boto_session) answer_exists = endpoint_exists(answer_endpoint_name, sagemaker_session.boto_session) if not cypher_exists or not answer_exists: print("One or both endpoints do not exist. Creating them.") # Upload training data before creating the training job train_data_location_cypher, train_data_location_answer, template_location_cypher, template_location_answer = upload_training_data(sagemaker_session) authenticate_huggingface() tokenizer = initialize_tokenizer() if not cypher_exists: cypher_query_gen_estimator = create_estimator( model_id="meta-textgeneration-llama-3-8b-instruct", model_version="*", role="arn:aws:iam::975050073207:role/SageMaker_Capability", environment={ "accept_eula": "true", "TOKENIZER_PATH": 'meta-llama/Meta-Llama-3-8B', # Pass the tokenizer path as an environment variable "HF_TASK": "text-generation", # Specify the task "HF_MODEL_ID": "/opt/ml/model" # Specify the model ID path }, instance_type="ml.g5.24xlarge", train_data_location=train_data_location_cypher, job_name=f"cypher-query-gen-{datetime.now().strftime('%Y%m%d%H%M%S')}", sagemaker_session=sagemaker_session ) if not answer_exists: answer_gen_estimator = create_estimator( 
model_id="meta-textgeneration-llama-3-8b-instruct", model_version="*", role="arn:aws:iam::975050073207:role/SageMaker_Capability", environment={ "accept_eula": "true", "TOKENIZER_PATH": 'meta-llama/Meta-Llama-3-8B', # Pass the tokenizer path as an environment variable "HF_TASK": "text-generation", # Specify the task "HF_MODEL_ID": "/opt/ml/model" # Specify the model ID path }, instance_type="ml.g5.24xlarge", train_data_location=train_data_location_answer, job_name=f"answer-gen-{datetime.now().strftime('%Y%m%d%H%M%S')}", sagemaker_session=sagemaker_session ) # Wait for both training jobs to complete if not cypher_exists: wait_for_training_completion(cypher_query_gen_estimator.latest_training_job.name, sagemaker_session.boto_session) deploy_model(cypher_query_gen_estimator, cypher_endpoint_name, sagemaker_session) if not answer_exists: wait_for_training_completion(answer_gen_estimator.latest_training_job.name, sagemaker_session.boto_session) deploy_model(answer_gen_estimator, answer_endpoint_name, sagemaker_session) # Generate Cypher query predictor = Predictor(endpoint_name=cypher_endpoint_name, sagemaker_session=sagemaker_session) generate_query_prompt = [ {"role": "system", "content": "You are a Cypher query writing robot who only ever outputs cypher code. Your querries should be general enogh to get all information that may be useful. You only speak in Cypher queries, and should only ever output the Cypher query. You should never output any explanation or any other text other than the Cypher query."}, {"role": "user", "content": f"Generate a Cypher query to that gets enough information that is available answer the following question based on the schema provided. Your cypher query should be as simple and concise as possible. It is better to return more nodes/relationships/information than less, so just write a simple query. 
\n\nQuestion: {question}\nSchema: {json.dumps(schema, indent=2)}\nEntities: {context['entities']}\nRelated Questions: {context['related_questions']}\nChat History: {context['chat_history']}"} ] formatted_prompt = format_messages(generate_query_prompt) payload = { "inputs": formatted_prompt, "parameters": { "max_new_tokens": 256, "do_sample": True, "temperature": 0.6, "top_p": 0.9, "return_full_text": False, } } body = json.dumps(payload) predictor.content_type = 'application/json' predictor.accept = 'application/json' print(f"Payload: {body}") try: response = predictor.predict(body) response_data = json.loads(response.decode('utf-8')) cypher_query = response_data[0]['generated_text'].strip() cleaned_cypher_query = clean_generated_query(cypher_query, schema) print(f"Generated Cypher Query: {cleaned_cypher_query}") results = execute_cypher_query(cleaned_cypher_query) results_str = json.dumps(results) print(f"Query Results: {results_str}") # Generate final answer answer_question_prompt = [ {"role": "system", "content": "You are a robot that answers questions based off of query results. Your goal is to be as concise and to the point as possible."}, {"role": "user", "content": f"Based on the following question and query results, provide a concise answer. 
Question: {question} Results: {results_str} Context: {context}"} ] formatted_answer_prompt = format_messages(answer_question_prompt) body = json.dumps({"inputs": formatted_answer_prompt, "parameters": payload["parameters"]}) print(f"Final Payload: {body}") answer_predictor = Predictor(endpoint_name=answer_endpoint_name, sagemaker_session=sagemaker_session) answer_predictor.content_type = 'application/json' answer_predictor.accept = 'application/json' final_response = answer_predictor.predict(body) final_response_data = json.loads(final_response.decode('utf-8')) cleaned_final_response = clean_final_response(final_response_data[0]['generated_text'].strip()) return cleaned_final_response except ClientError as error: print(f"Client error: {error}") except BotoCoreError as error: print(f"BotoCore error: {error}") except Exception as error: print(f"Unexpected error: {error}") if __name__ == "__main__": question = "What supplies are required for a knee replacement?" context = { "entities": ["Procedure"], "related_questions": [], "chat_history": [] } answer = handle_question(que`Preformatted text`stion, context) print(f"Final Answer: {answer}") ``` ` Here is what is normally output: `Payload: {"inputs": "system\n\nYou are a Cypher query writing robot who only ever outputs cypher code. Your querries should be general enogh to get all information that may be useful. You only speak in Cypher queries, and should only ever output the Cypher query. You should never output any explanation or any other text other than the Cypher query.\n\nuser\n\nGenerate a Cypher query to that gets enough information that is available answer the following question based on the schema provided. Your cypher query should be as simple and concise as possible. It is better to return more nodes/relationships/information than less, so just write a simple query. 
\n\nQuestion: What supplies are required for a knee replacement?\nSchema: {\n \"nodes\": {\n \"medical_supply\": [\n {\n \"quantity\": \"Integer\",\n \"usage_rate\": \"String\",\n \"name\": \"String\",\n \"expiration_date\": \"String\"\n }\n ],\n \"hospital_location\": [\n {\n \"address\": \"String\",\n \"cost\": \"Integer\",\n \"city\": \"String\",\n \"name\": \"String\",\n \"delivery_time\": \"Integer\"\n }\n ],\n \"medical_supplier\": [\n {\n \"cost\": \"Integer\",\n \"contact\": \"String\",\n \"name\": \"String\",\n \"delivery_time\": \"Integer\",\n \"email\": \"String\"\n }\n ],\n \"procedure\": [\n {\n \"datetime\": \"Unknown\",\n \"provider\": \"String\",\n \"name\": \"String\"\n }\n ],\n \"doctor\": [\n {\n \"medical_school\": \"String\",\n \"name\": \"String\",\n \"hospital_location_of_practice\": \"String\",\n \"high_school\": \"String\",\n \"previous_work_history\": \"String\",\n \"age\": \"Integer\",\n \"school\": \"String\",\n \"favorite_color\": \"String\",\n \"last_name\": \"String\",\n \"previous_work\": \"String\",\n \"location\": \"String\",\n \"position\": \"String\",\n \"first_name\": \"String\",\n \"username\": \"String\",\n \"height\": \"Integer\",\n \"highschool\": \"String\"\n }\n ],\n \"forecasted_procedure\": [\n {\n \"date\": \"String\",\n \"name\": \"String\",\n \"yhat\": \"Float\",\n \"yhat_lower\": \"Float\",\n \"ds\": \"String\",\n \"yhat_upper\": \"Float\"\n }\n ],\n \"medical_supply_group\": [\n {\n \"name\": \"String\"\n }\n ]\n },\n \"relationships\": [\n {\n \"start_node\": \"medical_supply\",\n \"relationship_type\": \"PROCUREMENT_OPTION\",\n \"end_node\": \"medical_supplier\",\n \"properties\": {\n \"cost\": \"Integer\",\n \"delivery_time\": \"Integer\"\n }\n },\n {\n \"start_node\": \"medical_supply\",\n \"relationship_type\": \"AVAILABLE_AT\",\n \"end_node\": \"hospital_location\",\n \"properties\": {\n \"cost\": \"Integer\",\n \"delivery_time\": \"Integer\"\n }\n },\n {\n \"start_node\": \"procedure\",\n 
\"relationship_type\": \"REQUIRES\",\n \"end_node\": \"medical_supply_group\",\n \"properties\": {}\n },\n {\n \"start_node\": \"procedure\",\n \"relationship_type\": \"SUPPLY_REQUEST\",\n \"end_node\": \"medical_supply\",\n \"properties\": {\n \"quantity\": \"Integer\",\n \"procedure_date\": \"String\",\n \"timestamp\": \"Integer\"\n }\n },\n {\n \"start_node\": \"doctor\",\n \"relationship_type\": \"PERFORMS\",\n \"end_node\": \"procedure\",\n \"properties\": {}\n },\n {\n \"start_node\": \"forecasted_procedure\",\n \"relationship_type\": \"IMPACTS\",\n \"end_node\": \"medical_supply_group\",\n \"properties\": {}\n },\n {\n \"start_node\": \"forecasted_procedure\",\n \"relationship_type\": \"PRECEDES\",\n \"end_node\": \"forecasted_procedure\",\n \"properties\": {}\n },\n {\n \"start_node\": \"medical_supply_group\",\n \"relationship_type\": \"CONTAINS_SUPPLY\",\n \"end_node\": \"medical_supply\",\n \"properties\": {}\n }\n ]\n}\nEntities: ['Procedure']\nRelated Questions: []\nChat History: []\n\nassistant\n\n", "parameters": {"max_new_tokens": 256, "do_sample": true, "temperature": 0.6, "top_p": 0.9, "return_full_text": false}} Generated Cypher Query: MATCH (p:procedure)-[:REQUIRES]->(ms:medical_supply)-[:PROCUREMENT_OPTION]->(msu:medical_supplier) RETURN p.name AS procedure, ms.name ASupply, msu.name ASupplier, msu.cost ASupplier_cost, msu.delivery_time ASupplier_delivery_time ORDER BY p.name Unexpected error: {code: Neo.ClientError.Statement.SyntaxError} {message: Invalid input 'ASupply': expected an expression, 'FOREACH', ',', 'AS', 'ORDER BY', 'CALL', 'CREATE', 'LOAD CSV', 'DELETE', 'DETACH', 'FINISH', 'INSERT', 'LIMIT', 'MATCH', 'MERGE', 'NODETACH', 'OPTIONAL', 'REMOVE', 'RETURN', 'SET', 'SKIP', 'UNION', 'UNWIND', 'USE', 'WITH' or <EOF> (line 1, column 136 (offset: 135)) "MATCH (p:procedure)-[:REQUIRES]->(ms:medical_supply)-[:PROCUREMENT_OPTION]->(msu:medical_supplier) RETURN p.name AS procedure, ms.name ASupply, msu.name ASupplier, msu.cost ASupplier_cost, 
msu.delivery_time ASupplier_delivery_time ORDER BY p.name" ^} Final Answer: None `
Beta Was this translation helpful? Give feedback.
All reactions
Replies: 1 comment
-
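The syntax error comes from the aliases. `ASupply`, `ASupplier`, `ASupplier_cost` and `ASupplier_delivery_time` are not valid Cypher, because `AS` must be followed by whitespace and then the alias (`AS supply`). The model most likely generated `AS supply` correctly. However, `clean_generated_query` removes every occurrence of the substring `" s"` (it is in the `to_remove` list), which turns `AS supply` into `ASupply` and `AS supplier` into `ASupplier`, while `AS procedure` survives intact. To fix the error at its source, remove `" s"` from that list (together with the bare `"SYS"` and `"INST"` entries), so that only whole special tokens are stripped.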
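The revised query below also follows the schema's relationship directions. `REQUIRES` points from `procedure` to `medical_supply_group`, and `CONTAINS_SUPPLY` leads from the group to `medical_supply`. The original `(p:procedure)-[:REQUIRES]->(ms:medical_supply)` pattern matches none of the relationships listed in the schema.
MATCH (p:procedure)-[:REQUIRES]->(:medical_supply_group)-[:CONTAINS_SUPPLY]->(ms:medical_supply)-[:PROCUREMENT_OPTION]->(msu:medical_supplier)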
RETURN p.name AS procedure, ms.name AS supply, msu.name AS supplier, msu.cost AS supplier_cost, msu.delivery_time AS supplier_delivery_time
ORDER BY p.name
Beta Was this translation helpful? Give feedback.