Update nnframes.md (#7808)
This commit is contained in:
parent
a1a1f914bb
commit
633668c882
1 changed files with 5 additions and 5 deletions
|
|
@ -131,7 +131,7 @@ This example trains a model with 3 inputs. And users can use VectorAssembler fro
|
|||
from bigdl.dllib.utils.common import *
|
||||
from bigdl.dllib.nnframes.nn_classifier import *
|
||||
from bigdl.dllib.feature.common import *
|
||||
from bigdl.dllib.keras.objectives import CategoricalCrossEntropy
|
||||
from bigdl.dllib.keras.objectives import SparseCategoricalCrossEntropy
|
||||
from bigdl.dllib.keras.optimizers import Adam
|
||||
from bigdl.dllib.keras.layers import *
|
||||
from bigdl.dllib.nncontext import *
|
||||
|
|
@ -147,9 +147,9 @@ spark = SparkSession\
|
|||
.getOrCreate()
|
||||
|
||||
df = spark.createDataFrame(
|
||||
[(1, 35, 109.0, Vectors.dense([2.0, 5.0, 0.5, 0.5]), 1.0),
|
||||
(2, 58, 2998.0, Vectors.dense([4.0, 10.0, 0.5, 0.5]), 2.0),
|
||||
(3, 18, 123.0, Vectors.dense([3.0, 15.0, 0.5, 0.5]), 1.0)],
|
||||
[(1, 35, 109.0, Vectors.dense([2.0, 5.0, 0.5, 0.5]), 0.0),
|
||||
(2, 58, 2998.0, Vectors.dense([4.0, 10.0, 0.5, 0.5]), 1.0),
|
||||
(3, 18, 123.0, Vectors.dense([3.0, 15.0, 0.5, 0.5]), 0.0)],
|
||||
["user", "age", "income", "history", "label"])
|
||||
|
||||
assembler = VectorAssembler(
|
||||
|
|
@ -171,7 +171,7 @@ merged = merge([flatten, dense1, gru], mode="concat")
|
|||
zy = Dense(2)(merged)
|
||||
|
||||
zmodel = Model([x1, x2, x3], zy)
|
||||
criterion = CategoricalCrossEntropy()
|
||||
criterion = SparseCategoricalCrossEntropy()
|
||||
classifier = NNEstimator(zmodel, criterion, [[1], [2], [2, 2]]) \
|
||||
.setOptimMethod(Adam()) \
|
||||
.setLearningRate(0.1)\
|
||||
|
|
|
|||
Loading…
Reference in a new issue