Skip to content

Commit eabf403

Browse files
authored
add token/revision in source + better error message
1 parent 9a5dc6b commit eabf403

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

pyspark_huggingface/huggingface_source.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,12 +80,14 @@ def __init__(self, options):
8080

8181
if "path" not in options or not options["path"]:
8282
raise Exception("You must specify a dataset name.")
83-
83+
8484
kwargs = dict(self.options)
8585
self.dataset_name = kwargs.pop("path")
8686
self.config_name = kwargs.pop("config", None)
8787
self.split = kwargs.pop("split", self.DEFAULT_SPLIT)
88+
self.revision = kwargs.pop("revision", None)
8889
self.streaming = kwargs.pop("streaming", "true").lower() == "true"
90+
self.token = kwargs.pop("token", None)
8991
for arg in kwargs:
9092
if kwargs[arg].lower() == "true":
9193
kwargs[arg] = True
@@ -96,8 +98,12 @@ def __init__(self, options):
9698
kwargs[arg] = ast.literal_eval(kwargs[arg])
9799
except ValueError:
98100
pass
101+
102+
# Raise the right error if the dataset doesn't exist
103+
api = self._get_api()
104+
api.repo_info(self.dataset_name, repo_type="dataset", revision=self.revision)
99105

100-
self.builder = load_dataset_builder(self.dataset_name, self.config_name, **kwargs)
106+
self.builder = load_dataset_builder(self.dataset_name, self.config_name, token=self.token, revision=self.revision, **kwargs)
101107
streaming_dataset = self.builder.as_streaming_dataset()
102108
if self.split not in streaming_dataset:
103109
raise Exception(f"Split {self.split} is invalid. Valid options are {list(streaming_dataset)}")
@@ -106,6 +112,11 @@ def __init__(self, options):
106112
if not self.streaming_dataset.features:
107113
self.streaming_dataset = self.streaming_dataset._resolve_features()
108114

115+
def _get_api(self):
116+
from huggingface_hub import HfApi
117+
118+
return HfApi(token=self.token, library_name="pyspark_huggingface")
119+
109120
@classmethod
110121
def name(cls):
111122
return "huggingfacesource"

0 commit comments

Comments
 (0)