
Study Keras: weight initialization

Yang Yang(Tony) edited this page Apr 23, 2019 · 1 revision
  • TensorFlow Version: 2.0
  • Keras Model API: subclassing tf.keras.Model

Sample Code

import tensorflow as tf


class DNNClassifier(tf.keras.Model):
    def __init__(self, feature_columns, hidden_units, n_classes):
        super(DNNClassifier, self).__init__()
        # Converts the feature-column inputs into a single dense tensor.
        self.feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
        # One Dense layer per entry in hidden_units. Only the layer objects
        # are created here; their weights do not exist yet.
        self.hidden_layers = []
        for hidden_unit in hidden_units:
            self.hidden_layers.append(tf.keras.layers.Dense(hidden_unit))
        self.prediction_layer = tf.keras.layers.Dense(n_classes, activation='softmax')

    def call(self, inputs):
        x = self.feature_layer(inputs)
        for hidden_layer in self.hidden_layers:
            x = hidden_layer(x)
        return self.prediction_layer(x)

Stacktrace on model.fit

 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(86)<module>()
-> model.fit(train_ds, validation_data=val_ds, epochs=model.default_training_epochs(), verbose=0)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(794)fit()
-> initial_epoch=initial_epoch)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1519)fit_generator()
-> steps_name='steps_per_epoch')
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_generator.py(257)model_iteration()
-> batch_outs = batch_function(*batch_data)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1242)train_on_batch()
-> extract_tensors_from_dataset=True)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2507)_standardize_user_data()
-> self._set_inputs(cast_inputs)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py(456)_method_wrapper()
-> result = method(self, *args, **kwargs)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2779)_set_inputs()
-> outputs = self.call(inputs)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(62)call()
-> x = hidden_layer(x)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(594)__call__()
-> self._maybe_build(inputs)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(1713)_maybe_build()
-> self.build(input_shapes)
> /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py(961)build()
-> dtype = dtypes.as_dtype(self.dtype or K.floatx())

Stacktrace on model.predict

 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(90)<module>()
-> model.predict(test_ds)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1147)predict()
-> callbacks=callbacks)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_generator.py(257)model_iteration()
-> batch_outs = batch_function(*batch_data)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_generator.py(531)predict_on_batch()
-> return model.predict_on_batch(x)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1365)predict_on_batch()
-> x, extract_tensors_from_dataset=True)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2507)_standardize_user_data()
-> self._set_inputs(cast_inputs)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py(456)_method_wrapper()
-> result = method(self, *args, **kwargs)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2779)_set_inputs()
-> outputs = self.call(inputs)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(62)call()
-> x = hidden_layer(x)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(594)__call__()
-> self._maybe_build(inputs)
 /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(1713)_maybe_build()
-> self.build(input_shapes)
> /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py(961)build()
-> dtype = dtypes.as_dtype(self.dtype or K.floatx())

Conclusion

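The main logic lives in Model._set_inputs, which both model.fit and model.predict reach through _standardize_user_data. It creates symbolic tensors and passes them through self.call to each layer's __call__. Inside layer.__call__, if the input is symbolic tensors, _maybe_build invokes layer.build. In both traces the call ends in Dense.build (core.py), so the weights of a Dense layer are created and initialized on the first call to the layer, not in its constructor.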
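The deferred-build pattern described above can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the Keras implementation: the class LazyDense and its seed parameter are hypothetical names introduced here, and only the method names __call__, _maybe_build and build mirror the frames in the stack traces. It assumes a Glorot-uniform kernel and a zero bias, which are the Keras defaults for Dense.

```python
import random


class LazyDense:
    """Toy stand-in for a Dense layer (hypothetical, not TensorFlow code)."""

    def __init__(self, units, seed=0):
        self.units = units
        self.built = False  # no weights exist until the first call
        self.kernel = None
        self.bias = None
        self._rng = random.Random(seed)

    def build(self, input_dim):
        # Glorot-uniform initialization: U(-limit, limit) with
        # limit = sqrt(6 / (fan_in + fan_out)).
        limit = (6.0 / (input_dim + self.units)) ** 0.5
        self.kernel = [
            [self._rng.uniform(-limit, limit) for _ in range(self.units)]
            for _ in range(input_dim)
        ]
        self.bias = [0.0] * self.units
        self.built = True

    def _maybe_build(self, inputs):
        # Build only once, using the input width seen on the first call.
        if not self.built:
            self.build(len(inputs[0]))

    def __call__(self, inputs):
        self._maybe_build(inputs)
        columns = list(zip(*self.kernel))  # one column of weights per unit
        return [
            [sum(x * w for x, w in zip(row, col)) + b
             for col, b in zip(columns, self.bias)]
            for row in inputs
        ]


layer = LazyDense(units=3)
print(layer.built)             # False: the constructor created no weights
outputs = layer([[1.0, 2.0]])  # the first call triggers build(input_dim=2)
print(layer.built)             # True
```

The point of the design is that the layer cannot know its input width in the constructor, so weight creation has to wait until the first input arrives. In the traces above, that first input is the symbolic tensor supplied by Model._set_inputs.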
