import json
import os
import subprocess
from datasets import get_dataset_config_names, load_dataset
from squadds.core.globals import *
from squadds.core.utils import (
compare_schemas,
delete_HF_cache,
get_config_schema,
get_entire_schema,
get_type,
validate_types,
)
from squadds.database.contributor_env import build_contributor_record, load_contributor_environment
from squadds.database.contributor_file_ops import (
append_entries_to_dataset_file,
load_contribution_from_json_file,
load_sweep_entries_from_json_prefix,
validate_sweep_entries,
)
from squadds.database.contributor_records import (
add_sim_result_entry,
build_contribution_payload,
build_empty_contribution_state,
merge_contributor_notes,
validate_required_structure,
)
from squadds.database.contributor_schema import validate_design_payload, validate_sim_setup_payload
from squadds.database.contributor_validation import get_nested_value, summarize_content_differences
"""
! TODO:
* Inputs the config/system data
* required schema generated based on the config/system data
"""
[docs]
class ExistingConfigData:
"""
Represents an existing configuration data object.
Attributes:
config (str): The name of the configuration.
sim_results (dict): A dictionary containing simulation results.
design (dict): A dictionary containing design options and the design tool.
sim_options (dict): A dictionary containing simulation setup options.
units (set): A set containing the units used in the simulation results.
notes (dict): A dictionary containing additional notes.
ref_entry (dict): A dictionary containing the reference entry.
contributor (dict): A dictionary containing contributor information.
entry (dict): A dictionary containing the contribution data.
local_repo_path (str): The local repository path.
sweep_data (list): A list containing sweep data.
Methods:
_validate_config_name(): Validates the configuration name.
get_config_schema(): Retrieves the schema for the given configuration name.
show_config_schema(): Prints the schema for the given configuration name.
_supported_config_names(): Retrieves the supported configuration names.
show(): Prints the contribution data.
__set_contributor_info(): Sets the contributor information.
get_contributor_info(): Retrieves the contributor information.
add_sim_result(result_name, result_value, unit): Adds a simulation result.
add_sim_setup(sim_setup): Adds simulation setup options to the contribution.
add_design(design): Adds a design to the contribution.
add_design_v0(design): Adds a design to the contribution (version 0).
to_dict(): Converts the contribution data to a dictionary.
clear(): Clears the contribution data.
add_notes(notes): Adds notes to the contribution.
validate_structure(actual_structure): Validates the structure of the contributor object.
_validate_structure(): Validates the structure of the contributor object.
validate_types(data): Validates the types of the data.
_validate_types(): Validates the types of the data.
_validate_content_v0(): Validates the content of the contribution against the dataset schema.
"""
def __init__(self, config=""):
    """
    Create an empty contribution bound to one dataset configuration.

    Args:
        config (str): Name of the dataset configuration to contribute to.

    Raises:
        ValueError: If ``config`` is not among the names returned by
            ``_supported_config_names()``.
    """
    self.__repo_name = "SQuADDS/SQuADDS_DB"
    self.config = config
    # Fail early: nothing else is initialised for an unknown configuration.
    self._validate_config_name()

    load_contributor_environment()

    # Start every payload section from the shared "empty" template.
    blank = build_empty_contribution_state()
    for section in ("sim_results", "design", "sim_options", "units", "notes"):
        setattr(self, section, blank[section])

    self.ref_entry = {}
    self.__set_contributor_info()
    # Snapshot of the (still empty) payload taken at construction time.
    self.entry = self.to_dict()
    self.__isValidated = False
    self.local_repo_path = ""
    self.sweep_data = []
[docs]
def _validate_config_name(self):
    """
    Check ``self.config`` against the configuration names the repository offers.

    Raises:
        ValueError: If ``self.config`` is not one of the names returned by
            ``_supported_config_names()``.
    """
    known_configs = self._supported_config_names()
    if self.config in known_configs:
        return
    raise ValueError(f"Invalid config name: {self.config}. Supported config names: {known_configs}")
[docs]
def get_config_schema(self):
    """
    Derive the schema of this configuration from the first ``train`` entry of its dataset.

    The dataset is loaded from the repository on every call, and the entry that
    was inspected is remembered on ``self.ref_entry``.

    Returns:
        The schema produced by the module-level ``get_config_schema`` helper
        for that first entry.
    """
    first_entry = load_dataset(self.__repo_name, self.config)["train"][0]
    # Keep the reference entry around; the content validators compare against it.
    self.ref_entry = first_entry
    return get_config_schema(first_entry)
[docs]
def show_config_schema(self):
    """
    Print the schema of this configuration as JSON indented by two spaces.

    The schema is derived from the first ``train`` entry of the configuration's
    dataset. This method does not touch ``self.ref_entry``.

    Returns:
        None
    """
    dataset = load_dataset(self.__repo_name, self.config)
    schema = get_config_schema(dataset["train"][0])
    rendered = json.dumps(schema, indent=2)
    print(rendered)
[docs]
def _supported_config_names(self):
    """
    Ask the dataset repository which configuration names it provides.

    ``delete_HF_cache()`` is called first and the lookup itself uses
    ``download_mode="force_redownload"``.

    Returns:
        The configuration names reported by ``get_dataset_config_names``.
    """
    delete_HF_cache()
    return get_dataset_config_names(self.__repo_name, download_mode="force_redownload")
# Pretty-print the contribution data as JSON (to_dict() builds the underlying dictionary)
[docs]
def show(self):
    """
    Print the current contribution as JSON indented by four spaces.

    Returns:
        None
    """
    payload = self.to_dict()
    print(json.dumps(payload, indent=4))
def __set_contributor_info(self):
    """
    Store a freshly built contributor record on ``self.contributor``.
    """
    record = build_contributor_record()
    self.contributor = record
[docs]
def get_contributor_info(self):
    """
    Return the contributor record attached to this contribution.

    Returns:
        The object stored on ``self.contributor`` (the record built by
        ``build_contributor_record()``); it is returned as-is, not copied.
    """
    record = self.contributor
    return record
[docs]
def add_sim_result(self, result_name, result_value, unit):
    """
    Record one simulation result together with its unit.

    Args:
        result_name (str): Name of the simulation result.
        result_value (float): Value of the simulation result.
        unit (str): Unit of measurement for the value.

    Returns:
        None
    """
    updated = add_sim_result_entry(self.sim_results, self.units, result_name, result_value, unit)
    # The helper hands back both collections; replace ours with what it returns.
    self.sim_results, self.units = updated
[docs]
def add_sim_setup(self, sim_setup):
    """
    Merge simulation setup options into the contribution after checking them.

    Args:
        sim_setup (dict): Simulation setup options; they are validated against
            the ``sim_options`` part of the configuration schema before being merged.
    """
    expected_options = self.get_config_schema().get("sim_options", {})
    validate_sim_setup_payload(sim_setup, expected_options, get_type)
    # Only reached when validation did not raise.
    self.sim_options.update(sim_setup)
[docs]
def add_design(self, design):
    """
    Merge a design into the contribution after checking it.

    Args:
        design (dict): Design options and design tool; validated against the
            ``design.design_options`` part of the configuration schema before being merged.
    """
    design_schema = self.get_config_schema().get("design", {})
    expected_options = design_schema.get("design_options", {})
    validate_design_payload(design, expected_options, get_type)
    # Only reached when validation did not raise.
    self.design.update(design)
[docs]
def add_design_v0(self, design):
    """
    Merge a design into the contribution, validating with ``require_design_tool=True``.

    Args:
        design (dict): Design options and design tool; validated against the
            ``design.design_options`` part of the configuration schema before being merged.
    """
    design_schema = self.get_config_schema().get("design", {})
    expected_options = design_schema.get("design_options", {})
    validate_design_payload(design, expected_options, get_type, require_design_tool=True)
    # Only reached when validation did not raise.
    self.design.update(design)
[docs]
def to_dict(self):
    """
    Assemble the contribution payload from the current state.

    Returns:
        The value returned by ``build_contribution_payload`` for the design,
        simulation options, simulation results, contributor, notes and units
        (passed in that order).
    """
    sections = (
        self.design,
        self.sim_options,
        self.sim_results,
        self.contributor,
        self.notes,
        self.units,
    )
    return build_contribution_payload(*sections)
[docs]
def clear(self):
    """
    Reset the contribution payload to its empty state and drop the validated flag.

    Only ``sim_results``, ``design``, ``sim_options``, ``units`` and ``notes``
    are reset; other attributes are left as they are.
    """
    blank = build_empty_contribution_state()
    for section in ("sim_results", "design", "sim_options", "units", "notes"):
        setattr(self, section, blank[section])
    self.__isValidated = False
[docs]
def add_notes(self, notes=None):
    """
    Combine ``notes`` with the notes already stored on the contribution.

    Args:
        notes (dict): Notes to add. Defaults to ``None``, which is passed to
            ``merge_contributor_notes`` unchanged.
    """
    merged = merge_contributor_notes(self.notes, notes)
    self.notes = merged
[docs]
def validate_structure(self, actual_structure):
    """
    Check a given structure against the schema of this configuration.

    Args:
        actual_structure (dict): The structure to check.

    Raises:
        ValueError: If any required key or sub-key is missing in the actual structure.
    """
    schema = self.get_config_schema()
    validate_required_structure(actual_structure, schema)
    print("Structure validated successfully....")
[docs]
def _validate_structure(self):
    """
    Check this contribution's own payload against the schema of its configuration.

    Raises:
        ValueError: If any required key or sub-key is missing in the payload.
    """
    schema = self.get_config_schema()
    validate_required_structure(self.to_dict(), schema)
    print("Structure validated successfully....")
[docs]
def validate_types(self, data):
    """
    Check the value types in ``data`` against the schema of this configuration.

    Args:
        data (dict): The data to be validated.
    """
    expected = self.get_config_schema()
    # Inside the method body this name resolves to the module-level helper.
    validate_types(data, expected)
    print("Types validated successfully....")
[docs]
def _validate_types(self):
    """
    Check the value types of this contribution's own payload against its configuration schema.
    """
    expected = self.get_config_schema()
    # Inside the method body this name resolves to the module-level helper.
    validate_types(self.to_dict(), expected)
    print("Types validated successfully....")
[docs]
def _validate_content_v0(self):
    """
    Require the payload's nested schemas to equal those of the reference entry.

    Compares ``design.design_options`` and ``sim_options.setup`` of the payload
    with the same sections of ``self.ref_entry``, printing both schemas as it goes.

    Raises:
        ValueError: On the first section whose schema differs from the reference.
    """
    contribution = self.to_dict()
    reference = self.ref_entry
    sections = (("design", "design_options"), ("sim_options", "setup"))
    for key, sub_key in sections:
        data_schema = get_entire_schema(contribution[key][sub_key])
        expected_schema = get_entire_schema(reference[key][sub_key])
        print(f"Key: {key}, Sub-key: {sub_key}")
        print(f"Data schema: {json.dumps(data_schema, indent=2)}")
        print(f"Expected schema: {json.dumps(expected_schema, indent=2)}")
        if data_schema != expected_schema:
            raise ValueError(
                f"Structure mismatch in '{key}.{sub_key}'. Expected: {expected_schema}, Got: {data_schema}"
            )
[docs]
def validate_content(self, data):
    """
    Placeholder content check: currently performs no validation at all.

    Args:
        data (dict): The data that would be validated; it is ignored.

    Returns:
        None
    """
    return
def _validate_content(self):
    """
    Report differences between this contribution's payload and the reference entry.

    Differences are only printed; nothing is raised or returned by this method itself.
    """
    contribution = self.to_dict()
    reference = self.ref_entry
    mismatched_keys, missing_keys = summarize_content_differences(contribution, reference)

    if mismatched_keys:
        print(
            "\nMismatched keys found. These keys are present in both dictionaries but have values of different types:\n"
        )
        for key in mismatched_keys:
            print(
                f"Key: {key}, data type in 'data': {type(get_nested_value(contribution, key))}, data type in 'ref': {type(get_nested_value(reference, key))}"
            )

    if missing_keys:
        print("\nMissing keys found. These keys are present in one dictionary but not the other:\n")
        for key in missing_keys:
            # A lookup that yields None on the payload side is reported as absent from 'data'.
            message = (
                f"Key: {key} is missing in 'data'"
                if get_nested_value(contribution, key) is None
                else f"Key: {key} is missing in 'ref'"
            )
            print(message)
def _validate_content_v1(self):
    """
    Compare the payload's nested schemas with those of the reference entry via ``compare_schemas``.

    Covers ``design.design_options`` and ``sim_options.setup``; the dotted
    section path is passed to ``compare_schemas`` as a prefix.
    """
    contribution = self.to_dict()
    reference = self.ref_entry
    sections = (("design", "design_options"), ("sim_options", "setup"))
    for key, sub_key in sections:
        data_schema = get_entire_schema(contribution[key][sub_key])
        expected_schema = get_entire_schema(reference[key][sub_key])
        print(f"Key: {key}, Sub-key: {sub_key}")
        compare_schemas(data_schema, expected_schema, f"{key}.{sub_key}.")
    print("Content validation passed.")
[docs]
def validate(self):
    """
    Run the structure, type and content checks on this contribution, in that order.

    The contribution is marked as validated only when all three checks finish
    without raising. An already validated contribution is not checked again.

    Raises:
        Exception: Whatever a failing check raised, after printing a short notice.
    """
    if self.is_validated:
        print("This contribution has already been validated.")
        return

    try:
        for check in (self._validate_structure, self._validate_types, self._validate_content):
            check()
        self.__isValidated = True
    except Exception:
        print("Validation failed.")
        raise
[docs]
def validate_sweep(self):
    """
    Validate the loaded sweep entries and mark the contribution as validated on success.

    ``self.sweep_data`` is handed to ``validate_sweep_entries`` together with the
    public ``validate_structure``, ``validate_types`` and ``validate_content``
    methods as callbacks. An already validated contribution is not checked again.

    Raises:
        Exception: Whatever the validation raised, after printing a short notice.

    Returns:
        None
    """
    if self.is_validated:
        print("This contribution has already been validated.")
        return

    try:
        validators = {
            "validate_structure_fn": self.validate_structure,
            "validate_types_fn": self.validate_types,
            "validate_content_fn": self.validate_content,
        }
        validate_sweep_entries(self.sweep_data, **validators)
        self.__isValidated = True
    except Exception:
        print("Validation failed.")
        raise
@property
def invalidate(self):
    """
    Mark the contribution as not validated.

    This is a property, so merely evaluating ``obj.invalidate`` clears the
    validated flag. That behaviour is kept for existing callers.

    Returns:
        callable: A no-op function. Returning it lets the natural spelling
        ``obj.invalidate()`` work as well, instead of failing with
        ``TypeError: 'NoneType' object is not callable`` after the flag was
        already cleared.
    """
    self.__isValidated = False
    return lambda: None
[docs]
def update_repo(self, path_to_repo):
    """
    Make sure an up-to-date local clone of the dataset repository exists under ``path_to_repo``.

    If ``<path_to_repo>/<repository name>`` already exists, ``git pull`` is run
    inside it; otherwise the repository is cloned into ``path_to_repo`` over SSH.
    The git commands are started with ``cwd=...``, so the working directory of
    the calling process is never changed.

    Args:
        path_to_repo (str): Directory that holds (or will hold) the clone. It is
            created if it does not exist.

    Raises:
        ValueError: If the contribution has not been validated.
        subprocess.CalledProcessError: If the git commands fail.
    """
    # Refuse to touch the filesystem until the contribution has been validated.
    if not self.is_validated:
        raise ValueError("Data must be validated before updating the repository.")

    # exist_ok avoids the check-then-create race of exists() followed by makedirs().
    os.makedirs(path_to_repo, exist_ok=True)

    repo_dir = os.path.join(path_to_repo, self.__repo_name.split("/")[-1])
    if os.path.exists(repo_dir):
        # Only the directory's presence is checked, not that it is a git work
        # tree; `git pull` fails (CalledProcessError) if it is not one.
        subprocess.run(["git", "pull"], cwd=repo_dir, check=True)
        return

    dataset_endpoint = f"git@hf.co:datasets/{self.__repo_name}"
    print(f"Cloning dataset repository from {dataset_endpoint} to {path_to_repo}...")
    # NOTE(review): these SSH options disable host-key verification for the
    # clone, which leaves it open to man-in-the-middle attacks. Kept as-is so
    # non-interactive clones keep working; consider trusting hf.co's key instead.
    subprocess.run(
        [
            "git",
            "-c",
            "core.sshCommand=ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no",
            "clone",
            dataset_endpoint,
        ],
        cwd=path_to_repo,
        check=True,
    )
[docs]
def update_db(self, path_to_repo, is_sweep=False):
    """
    Append the validated data to ``<config>.json`` inside the local repository clone.

    The working directory is switched to the clone only for the duration of the
    write and is restored afterwards, even if the write fails, so later calls
    that use a relative ``path_to_repo`` keep working.

    Args:
        path_to_repo (str): The directory that contains the local clone.
        is_sweep (bool): If True, append every entry of ``self.sweep_data``;
            otherwise append the single payload from ``to_dict()``.

    Raises:
        ValueError: If the data has not been validated.
    """
    if not self.is_validated:
        raise ValueError("Data must be validated before updating the repository.")

    repo_dir = os.path.join(path_to_repo, self.__repo_name.split("/")[-1])
    dataset_file = f"{self.config}.json"

    previous_cwd = os.getcwd()
    os.chdir(repo_dir)
    try:
        # The helper receives the bare file name, so it must run inside the clone.
        entries = self.sweep_data if is_sweep else [self.to_dict()]
        append_entries_to_dataset_file(dataset_file, entries)
    finally:
        os.chdir(previous_cwd)
    print(f"Data added to {dataset_file} successfully.")
[docs]
def upload_to_HF(self, path_to_repo):
    """
    Commit ``<config>.json`` in the local clone and push it to the remote.

    The git commands are started with ``cwd=...``, so the working directory of
    the calling process is never changed.

    Args:
        path_to_repo (str): The directory that contains the local clone.

    Raises:
        ValueError: If the data has not been validated.
        subprocess.CalledProcessError: If the git commands fail.

    Returns:
        None
    """
    if not self.is_validated:
        raise ValueError("Data must be validated before updating the repository.")

    repo_dir = os.path.join(path_to_repo, self.__repo_name.split("/")[-1])

    # Commit message and branch name are both derived from the contributor record.
    commit_message = f"Add {self.config} data from {self.contributor['group']} group by {self.contributor['uploader']} on {self.contributor['date_created']}"
    uploader_name = self.contributor["uploader"].replace(" ", "")
    uid = self.contributor["date_created"].replace(" ", "")
    branch_name = f"add_{self.config}_{uploader_name}_{uid}"

    try:
        subprocess.run(["git", "add", f"{self.config}.json"], cwd=repo_dir, check=True)
        subprocess.run(["git", "commit", "-m", commit_message], cwd=repo_dir, check=True)
    except subprocess.CalledProcessError:
        print(f"Failed to commit changes to {self.config}.json")
        raise

    try:
        # No environment juggling is needed here: child processes inherit
        # os.environ (including GITHUB_TOKEN when it is set). Re-assigning
        # os.environ["GITHUB_TOKEN"] = os.getenv("GITHUB_TOKEN") was a no-op
        # when the variable existed and a TypeError when it did not.
        # NOTE(review): nothing in this method creates or checks out
        # `branch_name`; confirm the branch exists before this push is relied on.
        subprocess.run(["git", "push", "--set-upstream", "origin", branch_name], cwd=repo_dir, check=True)
    except subprocess.CalledProcessError:
        print(f"Failed to create upstream branch for {self.config}.json")
        raise

    try:
        # Pushing requires the necessary permissions and authentication to be set up.
        subprocess.run(["git", "push"], cwd=repo_dir, check=True)
    except subprocess.CalledProcessError:
        print(f"Failed to push changes to {self.config}.json")
        raise
[docs]
def from_json(self, json_file, is_sweep=False):
    """
    Load a contribution, or a sweep of contributions, from JSON.

    Loading new data clears the validated flag: the freshly loaded content has
    not been checked yet, and ``validate()`` / ``validate_sweep()`` skip
    contributions that are still marked as validated.

    Args:
        json_file (str): Path to the JSON file, or the file prefix when
            ``is_sweep`` is True.
        is_sweep (bool): True if the contribution is a sweep, False otherwise.

    Raises:
        ValueError: If the file does not exist, or no sweep entries were found
            for the prefix.
        KeyError: If the file lacks ``design``, ``sim_options`` or ``sim_results``.
    """
    if is_sweep:
        sweep_data = load_sweep_entries_from_json_prefix(json_file, self.get_contributor_info())
        if not sweep_data:
            raise ValueError(f"No sweep files found for prefix: {os.path.abspath(json_file)}")
        self.sweep_data.extend(sweep_data)
        self.__isValidated = False
        print("Sweep data loaded successfully.")
        return

    file_path = os.path.abspath(json_file)
    if not os.path.exists(file_path):
        raise ValueError(f"File not found: {file_path}")
    data = load_contribution_from_json_file(file_path)

    # Read every required section before assigning any of them, so a missing
    # key cannot leave the object half-updated.
    design = data["design"]
    sim_options = data["sim_options"]
    sim_results = data["sim_results"]

    self.design = design
    self.sim_options = sim_options
    self.sim_results = sim_results
    self.__set_contributor_info()
    # Notes are optional; the current notes are kept when the file has none.
    if "notes" in data:
        self.notes = data["notes"]
    self.__isValidated = False
    print("Contribution loaded successfully.")
@property
def is_validated(self):
    """
    Tell whether this contribution is currently marked as validated.

    Returns:
        bool: The value of the internal validated flag.
    """
    flag = self.__isValidated
    return flag
[docs]
def contribute(self, path_to_repo, is_sweep=False):
    """
    Prepare a validated contribution for a pull request.

    Refreshes the local clone with ``update_repo`` and then writes the data
    into it with ``update_db``. Nothing is uploaded by this method.

    Args:
        path_to_repo (str): The path to the repository.
        is_sweep (bool): True if the contribution is a sweep, False otherwise.

    Raises:
        ValueError: If the contribution has not been validated.

    Returns:
        None
    """
    if self.is_validated:
        self.update_repo(path_to_repo)
        self.update_db(path_to_repo, is_sweep)
        print("Contribution ready for PR")
        return
    raise ValueError("Data must be validated before contributing.")
[docs]
def submit(self):
    """
    Placeholder for sending the data and the config name to a remote server.

    Raises:
        NotImplementedError: Always; the method has no implementation yet.
    """
    message = "This method is not implemented yet."
    raise NotImplementedError(message)