Update Orca 5min (#6524)
This commit is contained in:
parent
7da102243e
commit
9c99751288
1 changed files with 5 additions and 0 deletions
|
|
@ -50,6 +50,7 @@ Finally, use [sklearn-style Estimator APIs in Orca](distributed-training-inferen
|
||||||
```python
|
```python
|
||||||
from bigdl.orca.learn.tf2.estimator import Estimator
|
from bigdl.orca.learn.tf2.estimator import Estimator
|
||||||
|
|
||||||
|
# Define the NCF model in standard TensorFlow API
|
||||||
def model_creator(config):
|
def model_creator(config):
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
|
|
||||||
|
|
@ -79,6 +80,8 @@ val_steps = int(test_df.count() / batch_size)
|
||||||
|
|
||||||
est = Estimator.from_keras(model_creator=model_creator, backend="spark",
|
est = Estimator.from_keras(model_creator=model_creator, backend="spark",
|
||||||
config={"embed_dim": 8, "num_users": num_users, "num_items": num_items})
|
config={"embed_dim": 8, "num_users": num_users, "num_items": num_items})
|
||||||
|
|
||||||
|
# Distributed training
|
||||||
est.fit(data=train_df,
|
est.fit(data=train_df,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
epochs=4,
|
epochs=4,
|
||||||
|
|
@ -87,6 +90,8 @@ est.fit(data=train_df,
|
||||||
steps_per_epoch=train_steps,
|
steps_per_epoch=train_steps,
|
||||||
validation_data=test_df,
|
validation_data=test_df,
|
||||||
validation_steps=val_steps)
|
validation_steps=val_steps)
|
||||||
|
|
||||||
|
# Distributed inference
|
||||||
prediction_df = est.predict(test_df,
|
prediction_df = est.predict(test_df,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
feature_cols=['user', 'item'],
|
feature_cols=['user', 'item'],
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue