179179"""
180180from enum import Enum , auto
181181import graph
182+ import random
183+ import urllib .request
184+ import itertools
185+ import pickle
186+ from six import string_types
182187
183188
184189class Tokenization (Enum ):
@@ -220,7 +225,7 @@ def __init__(self, level, tokenization=None):
220225 f"be of acceptable type" )
221226 # instance attributes
222227 self ._level = level
223- self ._tokenization = tokenization
228+ self ._mode = tokenization
224229 # initialize our Markov Chain
225230 self .chain = graph .Graph ()
226231 # initialize our probability sum container to hold the total sum of
@@ -234,11 +239,11 @@ def level(self):
         return self._level

     @property
-    def tokenization(self):
+    def mode(self):
         """Getter that returns the given RandomWriter's tokenization
         attribute.
         """
-        return self._tokenization
+        return self._mode

     def add_chain(self, data):
         """Function to add a new state to our Markov Chain, uses our Graph's
@@ -276,8 +281,29 @@ def add_conn(self, source, dest, token):
276281 """
277282 # add an edge from source's vert obj to dest's vert obj
278283 self .chain [source ].add_edge (self .chain [dest ], token )
279- # add 1 to the number of outgoing edges from the source destination
280- self ._incr_prob_sum (source )
284+ 285+ def _choose_rand_edge (self , vertex ):
286+ """Randomly traversing the Markov Chain that we have already
287+ constructed with our train algorithm, output a single token
288+
289+ Use our random graph traversal algorithm where we choose a number
290+ between 1 the sum of the total weights, if that number is less than
291+ the state we are currently evaluating, then we have found our next
292+ path, otherwise we subtract that number from our randomly chosen
293+ number and continue to evaluate the new difference against the next
294+ state
295+ """
296+ # index into our total probability sum taken from all of the
297+ rand_val = random .randint (0 , self .prob_sums [vertex .data ])
298+ # iterate over our current outgoing edges
299+ for choice in vertex .outgoing_edges :
300+ curr_edge = vertex .get_edge (choice )
301+ # current edge's weight is less than the random val, success
302+ if rand_val <= curr_edge .weight :
303+ return curr_edge
304+ # otherwise subtract that edge's weight from rand_val and continue
305+ else :
306+ rand_val -= curr_edge .weight

     def generate(self):
         """Generate tokens using the model.
@@ -291,8 +317,25 @@ def generate(self):
         new starting node at random and continuing.

         """
-        # TODO: GENERATOR OBJ FOR ONCE WE HAVE OUR GRAPH
-        raise NotImplementedError
+        # randomly select a starting state until we find one w/ outgoing edges
+        state = random.choice(list(self.chain.vertices))
+        vertex = self.chain[state]
+        # ensure that we at least pick a starting state w/ outgoing edges
+        while not vertex.outgoing_edges:
+            state = random.choice(list(self.chain.vertices))
+            vertex = self.chain[state]
+
+        # continue to traverse and generate output indefinitely
+        while True:
+            # choose an edge weighted-randomly
+            curr_edge = self._choose_rand_edge(vertex)
+            yield curr_edge.token
+            # go to the next vertex, taking the edge we just yielded
+            vertex = curr_edge.dest_vertex
+            # handle case where vertex has no outgoing edges
+            while not vertex.outgoing_edges:
+                state = random.choice(list(self.chain.vertices))
+                vertex = self.chain[state]

     def generate_file(self, filename, amount):
         """Write a file using the model.
@@ -310,8 +353,21 @@ def generate_file(self, filename, amount):

         Make sure to open the file in the appropriate mode.
         """
-        # TODO: OUTPUT GENERATOR YIELDED RANDOMIZED TOKENS TO FILE
-        raise NotImplementedError
+        # open the file in byte mode if byte tokenized
+        fi = open(filename, "wb") if self.mode is Tokenization.byte else \
+            open(filename, "w", encoding="utf-8")
+        # only get the first "amount" elements in our generated data
+        for token in itertools.islice(self.generate(), amount):
+            # make sure we correctly format our output
+            if self.mode is Tokenization.word:
+                fi.write(token + " ")
+            elif self.mode is Tokenization.none or self.mode is None:
+                fi.write(str(token) + " ")
+            elif self.mode is Tokenization.byte:
+                fi.write(bytes([token]))
+            else:
+                fi.write(str(token))
+        fi.close()

     def save_pickle(self, filename_or_file_object):
         """Write this model out as a Python pickle.
@@ -324,7 +380,18 @@ def save_pickle(self, filename_or_file_object):
         in binary mode.

         """
-        raise NotImplementedError
+        # file object
+        if hasattr(filename_or_file_object, "read"):
+            # save this RandomWriter to a pickle
+            pickle.dump(self, filename_or_file_object)
+        # file name
+        elif isinstance(filename_or_file_object, string_types):
+            # open the file in the correct mode
+            with open(filename_or_file_object, "wb") as fi:
+                pickle.dump(self, fi)
+        else:
+            raise ValueError(f"Error: {filename_or_file_object} is not a "
+                             f"filename or file object")

     @classmethod
     def load_pickle(cls, filename_or_file_object):
@@ -341,7 +408,18 @@ def load_pickle(cls, filename_or_file_object):
         in binary mode.

         """
-        raise NotImplementedError
+        # file object
+        if hasattr(filename_or_file_object, "read"):
+            # load the RandomWriter back out of the pickle
+            return pickle.load(filename_or_file_object)
+        # file name
+        elif isinstance(filename_or_file_object, string_types):
+            # open the file in the correct mode
+            with open(filename_or_file_object, "rb") as fi:
+                return pickle.load(fi)
+        else:
+            raise ValueError(f"Error: {filename_or_file_object} is not a "
+                             f"filename or file object")

     def train_url(self, url):
         """Compute the probabilities based on the data downloaded from url.
@@ -356,7 +434,21 @@ def train_url(self, url):
         Do not duplicate any code from train_iterable.

         """
-        raise NotImplementedError
+        # Ensure that the mode is correct
+        if self.mode is Tokenization.none or self.mode is None:
+            raise ValueError("Error: this type of training is only supported "
+                             "if the tokenization mode is not none")
+
+        # Open the url and read in the data
+        with urllib.request.urlopen(url) as f:
+            # if byte mode, we don't have to decode
+            if self.mode is Tokenization.byte:
+                data = f.read()
+            # otherwise, make sure we decode as utf-8
+            else:
+                data = f.read().decode()
+        # train on the data
+        self.train_iterable(data)

     def _gen_tokenized_data(self, data, size):
         """Helper function to generate tokens of proper length based on the
@@ -370,7 +462,7 @@ def _gen_tokenized_data(self, data, size):
         else if tokenization is byte then data is a bytestream

         NOTE: code taken from Arthur Peters' windowed() function in
-        final_tests.py, Thanks Mr. Peters!
+        final_tests.py

         TODO: handle k = 0 level case
         """
@@ -389,41 +481,7 @@ def _gen_tokenized_data(self, data, size):
             window.append(elem)
             # if the window has reached specified size, yield the proper state
             if len(window) == size:
-                # tokenize by string
-                if self.tokenization is Tokenization.character or \
-                        self.tokenization is Tokenization.word:
-                    yield "".join(window)
-                # tokenize by byte
-                elif self.tokenization is Tokenization.byte:
-                    yield b"".join(window)
-                # simply yield another iterable
-                else:
-                    yield tuple(window)
-
-    def _data_type_check(self, data):
-        """Helper function to make sure that the data is in the correct form
-        for this RandomWriter's Tokenization
-
-        If the tokenization mode is none, data must be an iterable. If
-        the tokenization mode is character or word, then data must be
-        a string. Finally, if the tokenization mode is byte, then data
-        must be a bytes. If the type is wrong raise TypeError.
-        """
-        # if in character or word tokenization, data must be a str
-        if self.tokenization is Tokenization.character or self.tokenization \
-                is Tokenization.word:
-            return isinstance(data, str)
-        # if in byte tokenization, data must by raw bytes
-        elif self.tokenization is Tokenization.byte:
-            return isinstance(data, bytes)
-        # if in none tokenization, data must be an iterable
-        elif self.tokenization is Tokenization.none or self.tokenization is \
-                Tokenization.none.value:
-            return hasattr(data, '__iter__')
-        # something went wrong with the constructor
-        else:
-            raise TypeError("Error: this RandomWriter does not have a proper "
-                            "tokenization")
+                yield tuple(window)

     def _build_markov_chain(self, states):
         """Helper function to help build the Markov Chain graph and compute
@@ -456,6 +514,8 @@ def _build_markov_chain(self, states):
             self.add_chain(new_state)
             # add a connection from the old chain to the new chain
             self.add_conn(old_state, new_state, new_state[-1])
+            # update the old state's total probability sum
+            self._incr_prob_sum(old_state)
             # iterate to the next state
             old_state = new_state

@@ -466,19 +526,9 @@ def train_iterable(self, data):
         simpler to store it don't worry about it. For most input types
         you will not need to store it.
         """
-        # type check on the input data
-        if not self._data_type_check(data):
-            raise TypeError(f"Error: data is not in correct form for this "
-                            f"RW's mode -> {self.tokenization}")
-
         # if we are in word tokenization then split data along white spaces
-        if self.tokenization is Tokenization.word:
+        if self.mode is Tokenization.word:
             states = self._gen_tokenized_data(data.split(), self.level)
-        # if we are in byte tokenization then split data by byte
-        elif self.tokenization is Tokenization.byte:
-            states = self._gen_tokenized_data((data[i:i + 1] for i in range(
-                len(data))), self.level)
-        # otherwise, iterate over data normally to create new gen of states
         else:
             states = self._gen_tokenized_data(data, self.level)
         # build our Markov Chain based on the generator we've constructed from
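
Putting the pieces together, a minimal sketch of what training builds: at level 2 with character tokenization, the windowed states of "abcd" become chain vertices, and each consecutive pair of states should become an edge labelled with the last element of the newer state:

    rw = RandomWriter(2, Tokenization.character)
    rw.train_iterable("abcd")
    # expected vertices: ('a', 'b'), ('b', 'c'), ('c', 'd')
    # expected edges:    ('a', 'b') -> ('b', 'c') carrying token 'c'
    #                    ('b', 'c') -> ('c', 'd') carrying token 'd'
    print(list(rw.chain.vertices))  # iteration order depends on the graph module
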