"""Upload a Hugging Face dataset to a Nomic Atlas map, inferring field types from a sample."""

import os
import time

import nomic
import pandas as pd
import pyarrow as pa
from dateutil.parser import parse
from tqdm import tqdm
from datasets import (
    load_dataset,
    get_dataset_split_names,
    get_dataset_config_names,
    ClassLabel,
    utils,
)

utils.logging.set_verbosity_error()

def get_datum_fields(dataset_dict, n_samples=100, unique_cutoff=20):
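    """Infer a field typology for an Atlas upload from a sample of the dataset.

    Buckets every schema field into numeric / string / bool / list / label /
    categorical / datetime / uncategorized, and picks the wordiest string
    field as the default field to index.
    """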
    dataset = dataset_dict["first_split_dataset"]
    sample = pd.DataFrame(dataset.shuffle(seed=42).take(n_samples))
    features = dataset.features

    indexable_field = None
    numeric_fields = []
    string_fields = []
    bool_fields = []
    list_fields = []
    label_fields = []
    categorical_fields = []
    datetime_fields = []
    uncategorized_fields = []

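    # A cutoff below 1 is interpreted as a fraction of the sample size.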
    if unique_cutoff < 1:
        unique_cutoff = unique_cutoff * len(sample)

    for field, dtype in dataset_dict["schema"].items():
        try:
            num_unique = sample[field].nunique()
        except TypeError:
            # Unhashable values (e.g. lists) can't be counted; treat every row as unique.
            num_unique = len(sample)

        if dtype == "string":
            if num_unique < unique_cutoff:
                categorical_fields.append(field)
            else:
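                # High-cardinality string fields are sniffed as datetimes:
                # the field qualifies only if every sampled value parses.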
                is_datetime = True
                for value in sample[field]:
                    try:
                        parse(value, fuzzy=False)
                    except (ValueError, OverflowError, TypeError):
                        is_datetime = False
                        break
                if is_datetime:
                    datetime_fields.append(field)
                else:
                    string_fields.append(field)

| | elif dtype in ("float"): |
| | numeric_fields.append(field) |
| | |
| | elif dtype in ("int64", "int32", "int16", "int8"): |
| | if features is not None and field in features and isinstance(features[field], ClassLabel): |
| | label_fields.append(field) |
| | elif num_unique < unique_cutoff: |
| | categorical_fields.append(field) |
| | else: |
| | numeric_fields.append(field) |
| | |
| | elif dtype == "bool": |
| | bool_fields.append(field) |
| |
|
| | elif "list" == dtype[0:4]: |
| | list_fields.append(field) |
| |
|
| | else: |
| | uncategorized_fields.append(field) |
| |
|
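    # Use the string field with the largest total word count in the sample as
    # the default field to index.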
    longest_length = 0
    for field in string_fields:
        length = 0
        for i in range(len(sample)):
            if sample[field][i]:
                length += len(str(sample[field][i]).split())
        if length > longest_length:
            longest_length = length
            indexable_field = field

    return (features,
            numeric_fields,
            string_fields,
            bool_fields,
            list_fields,
            label_fields,
            categorical_fields,
            datetime_fields,
            uncategorized_fields,
            indexable_field)


def load_dataset_and_metadata(dataset_name,
                              config=None,
                              streaming=True):
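    """Load the first split of `dataset_name` (streaming by default) along with
    the metadata the uploader needs: config, split names, schema, and a head table.
    """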
    configs = get_dataset_config_names(dataset_name)
    if config is None:
        config = configs[0]

    splits = get_dataset_split_names(dataset_name, config)
    dataset = load_dataset(dataset_name, config, split=splits[0], streaming=streaming)
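    # _head() is a private helper on streaming datasets that materializes the
    # first few examples so a concrete Arrow schema can be derived.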
    head = pa.Table.from_pydict(dataset._head())

    schema_dict = {field.name: str(field.type) for field in head.schema}

    dataset_dict = {
        "first_split_dataset": dataset,
        "name": dataset_name,
        "config": config,
        "splits": splits,
        "schema": schema_dict,
        "head": head,
    }

    return dataset_dict


def upload_dataset_to_atlas(dataset_dict,
                            atlas_api_token: str,
                            project_name=None,
                            unique_id_field_name=None,
                            indexed_field=None,
                            modality=None,
                            organization_name=None,
                            wait_for_map=True,
                            datum_limit=30000):
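    """Stream every split of the dataset into a new Atlas project, build an
    index over `indexed_field`, and return the resulting map link. The upload
    stops once a split reaches `datum_limit` examples.
    """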
    nomic.login(atlas_api_token)

    if modality is None:
        modality = "text"

    if unique_id_field_name is None:
        unique_id_field_name = "atlas_datum_id"

    if project_name is None:
        project_name = dataset_dict["name"].replace("/", "--") + "--hf-atlas-map"

    desc = f"Config: {dataset_dict['config']}"

    (features,
     numeric_fields,
     string_fields,
     bool_fields,
     list_fields,
     label_fields,
     categorical_fields,
     datetime_fields,
     uncategorized_fields,
     indexable_field) = get_datum_fields(dataset_dict)

    if indexed_field is None:
        indexed_field = indexable_field

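    # Embedding projects have no text field to index; reuse the field chosen
    # above to label topics instead.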
    topic_label_field = None
    if modality == "embedding":
        topic_label_field = indexed_field
        indexed_field = None

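    # Fields that upload cleanly via str(); list fields get a dedicated pass
    # below so a failing conversion can be skipped.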
    easy_fields = string_fields + bool_fields + categorical_fields

    proj = nomic.AtlasProject(name=project_name,
                              modality=modality,
                              unique_id_field=unique_id_field_name,
                              organization_name=organization_name,
                              description=desc,
                              reset_project_if_exists=True)

    # ClassLabel fields get a human-readable "<field>_name" companion column.
    colorable_fields = ["split"] + [field + "_name" for field in label_fields]

    batch_size = 1000
    batched_texts = []

    allow_upload = True

    for split in dataset_dict["splits"]:

        if not allow_upload:
            break

        dataset = load_dataset(dataset_dict["name"], dataset_dict["config"], split=split, streaming=True)

        for i, ex in tqdm(enumerate(dataset)):
            # Brief pause every 10k examples to avoid hammering the API.
            if i % 10000 == 0:
                time.sleep(2)
            if i == datum_limit:
                print(f"Datum upload limited to {datum_limit:,} points. Stopping upload...")
                allow_upload = False
                break

| | data_to_add = {"split": split, unique_id_field_name: f"{split}_{i}"} |
| |
|
| | for field in numeric_fields: |
| | data_to_add[field] = ex[field] |
| |
|
| | for field in easy_fields: |
| | val = "" |
| | if ex[field]: |
| | val = str(ex[field]) |
| | data_to_add[field] = val |
| |
|
            for field in datetime_fields:
                try:
                    data_to_add[field] = parse(ex[field], fuzzy=False)
                except (ValueError, OverflowError, TypeError):
                    data_to_add[field] = None

            for field in label_fields:
                label_name = ""
                index = ex[field]
                # ClassLabel uses -1 for missing labels.
                if index is not None and index != -1:
                    label_name = features[field].names[index]
                data_to_add[field] = str(index)
                data_to_add[field + "_name"] = label_name

            for field in list_fields:
                list_str = ""
                if ex[field]:
                    try:
                        list_str = str(ex[field])
                    except Exception:
                        continue
                data_to_add[field] = list_str

            batched_texts.append(data_to_add)

            if len(batched_texts) >= batch_size:
                proj.add_text(batched_texts)
                batched_texts = []

    # Flush any remaining examples after all splits are processed.
    if len(batched_texts) > 0:
        proj.add_text(batched_texts)

    colorable_fields = (colorable_fields + categorical_fields +
                        label_fields + bool_fields + datetime_fields)

    projection = proj.create_index(name=project_name + " index",
                                   indexed_field=indexed_field,
                                   colorable_fields=colorable_fields,
                                   topic_label_field=topic_label_field,
                                   build_topic_model=True)

    if wait_for_map:
        with proj.wait_for_project_lock():
            time.sleep(1)

    return projection.map_link


if __name__ == "__main__":
    dataset_name = "databricks/databricks-dolly-15k"

    project_name = "huggingface_auto_upload_test-dolly-15k"

    dataset_dict = load_dataset_and_metadata(dataset_name)
    # Never hard-code an API token; read it from the environment instead
    # (the variable name ATLAS_API_TOKEN is a convention, not required by nomic).
    api_token = os.environ["ATLAS_API_TOKEN"]
    print(upload_dataset_to_atlas(dataset_dict, api_token, project_name=project_name))