@@ -26,26 +26,19 @@ def prepare_dataset(
     tokenizer_name: str = "gpt2",
     max_length: int = 1024,
     cache_dir: Optional[str] = None,
-    num_proc: int = 4
+    num_proc: int = 4,
+    text_column: Optional[str] = None
 ) -> Tuple[PYTORCH_Dataset, Optional[PYTORCH_Dataset], AutoTokenizer]:
3132 """ 
3233 Prepare a Hugging Face dataset for training. 
3334
3435 Args: 
-        dataset_name: Name of the dataset to load (e.g., "wikitext/wikitext-2-raw-v1").
-        tokenizer_name: Name of the tokenizer to use (e.g., "gpt2").
-        max_length: Maximum sequence length for tokenized inputs.
-        cache_dir: Directory to cache the dataset.
-        num_proc: Number of processes for tokenization.
-
-    Returns:
-        A tuple containing:
-        - Train dataset (PYTORCH_Dataset)
-        - Validation dataset (Optional[PYTORCH_Dataset])
-        - Tokenizer (AutoTokenizer)
-
-    Raises:
-        DatasetPreparationError: If there is an issue with loading or tokenizing the dataset.
+        dataset_name: Name of the dataset to load
+        tokenizer_name: Name of the tokenizer to use
+        max_length: Maximum sequence length
+        cache_dir: Directory to cache the dataset
+        num_proc: Number of processes for tokenization
+        text_column: Name of the text column (auto-detect if None)
     """
     try:
         # Load tokenizer
@@ -68,11 +61,37 @@ def prepare_dataset(
     except Exception as e:
         raise DatasetPreparationError(f"Failed to load dataset {dataset_name}: {str(e)}")
 
-    # Tokenize the dataset
+    # Auto-detect text column if not specified
+    if text_column is None:
+        # Common column names for text data
+        possible_columns = ['text', 'content', 'input_text', 'sentence', 'document']
+        available_columns = dataset['train'].column_names
+
+        # Find the first matching column
+        text_column = next(
+            (col for col in possible_columns if col in available_columns),
+            None
+        )
+
+        # If no standard column found, look for any string column
+        if text_column is None:
+            for col in available_columns:
+                if isinstance(dataset['train'][0][col], str):
+                    text_column = col
+                    break
+
+        if text_column is None:
+            raise DatasetPreparationError(
+                f"Could not detect text column. Available columns: {available_columns}"
+            )
+
+        logger.info(f"Auto-detected text column: {text_column}")
+
+    # Tokenize function with dynamic column handling
     def tokenize_function(examples):
         try:
             return tokenizer(
-                examples["text"],
+                examples[text_column],
                 padding="max_length",
                 truncation=True,
                 max_length=max_length,
@@ -82,11 +101,14 @@ def tokenize_function(examples):
             raise DatasetPreparationError(f"Tokenization failed: {str(e)}")
 
     logger.info("Tokenizing dataset")
+    # Remove only the text column used for tokenization
+    remove_columns = [text_column] if text_column in dataset['train'].column_names else None
+
     tokenized_dataset = dataset.map(
         tokenize_function,
         batched=True,
         num_proc=num_proc,
-        remove_columns=["text"]
+        remove_columns=remove_columns
     )
 
     # Convert to PyTorch datasets
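For reviewers, a minimal usage sketch of the changed signature. It assumes only what this diff shows (`prepare_dataset` returning a `(train, validation, tokenizer)` tuple, with the old docstring's `"wikitext/wikitext-2-raw-v1"` example); the second dataset name and its `document_body` column are hypothetical placeholders, not names from this repository.

```python
# Minimal sketch, not from this repository's test suite.

# 1) Rely on auto-detection: wikitext exposes a "text" column, which is the
#    first entry in possible_columns, so text_column can be left as None.
train_ds, val_ds, tokenizer = prepare_dataset(
    dataset_name="wikitext/wikitext-2-raw-v1",
    tokenizer_name="gpt2",
    max_length=1024,
)

# 2) Override auto-detection for a dataset with a non-standard column name.
#    "my_org/legal-corpus" and "document_body" are hypothetical stand-ins.
train_ds, val_ds, tokenizer = prepare_dataset(
    dataset_name="my_org/legal-corpus",
    tokenizer_name="gpt2",
    text_column="document_body",  # skips the auto-detect branch entirely
)
```

One consequence of the detection order worth noting: if a dataset exposes both `text` and `content` columns, `text` wins because it appears first in `possible_columns`; passing `text_column` explicitly is the way to pick any other column.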