Update Orca 5min (#6524)

This commit is contained in:
Kai Huang 2022-11-10 10:57:19 +08:00 committed by GitHub
parent 7da102243e
commit 9c99751288

View file

@ -50,6 +50,7 @@ Finally, use [sklearn-style Estimator APIs in Orca](distributed-training-inferen
```python
from bigdl.orca.learn.tf2.estimator import Estimator
# Define the NCF model in standard TensorFlow API
def model_creator(config):
from tensorflow import keras
@ -79,6 +80,8 @@ val_steps = int(test_df.count() / batch_size)
est = Estimator.from_keras(model_creator=model_creator, backend="spark",
config={"embed_dim": 8, "num_users": num_users, "num_items": num_items})
# Distributed training
est.fit(data=train_df,
batch_size=batch_size,
epochs=4,
@ -87,6 +90,8 @@ est.fit(data=train_df,
steps_per_epoch=train_steps,
validation_data=test_df,
validation_steps=val_steps)
# Distributed inference
prediction_df = est.predict(test_df,
batch_size=batch_size,
feature_cols=['user', 'item'],