Update Orca 5min (#6524)
This commit is contained in:
parent
7da102243e
commit
9c99751288
1 changed files with 5 additions and 0 deletions
|
|
@ -50,6 +50,7 @@ Finally, use [sklearn-style Estimator APIs in Orca](distributed-training-inferen
|
|||
```python
|
||||
from bigdl.orca.learn.tf2.estimator import Estimator
|
||||
|
||||
# Define the NCF model in standard TensorFlow API
|
||||
def model_creator(config):
|
||||
from tensorflow import keras
|
||||
|
||||
|
|
@ -79,6 +80,8 @@ val_steps = int(test_df.count() / batch_size)
|
|||
|
||||
est = Estimator.from_keras(model_creator=model_creator, backend="spark",
|
||||
config={"embed_dim": 8, "num_users": num_users, "num_items": num_items})
|
||||
|
||||
# Distributed training
|
||||
est.fit(data=train_df,
|
||||
batch_size=batch_size,
|
||||
epochs=4,
|
||||
|
|
@ -87,6 +90,8 @@ est.fit(data=train_df,
|
|||
steps_per_epoch=train_steps,
|
||||
validation_data=test_df,
|
||||
validation_steps=val_steps)
|
||||
|
||||
# Distributed inference
|
||||
prediction_df = est.predict(test_df,
|
||||
batch_size=batch_size,
|
||||
feature_cols=['user', 'item'],
|
||||
|
|
|
|||
Loading…
Reference in a new issue