4000-520-616
欢迎来到免疫在线!(蚂蚁淘生物旗下平台)  请登录 |  免费注册 |  询价篮
主营:原厂直采,平行进口,授权代理(蚂蚁淘为您服务)
咨询热线电话
4000-520-616
当前位置: 首页 > 新闻动态 >
新闻详情
使用预训练Embedding,finetune DSSM模型 - 知乎
来自 : 知乎 发布时间:2021-03-25
model.__setattr__( item_input , sku_vec) model.__setattr__( user_embedding , pin_part) model.__setattr__( item_embedding , sku_part) return model


其中:“用户”端的 Embedding 和 “商品”端的 Embedding 向量维度均为128维。(输入的Embedding向量是已经预训练完毕的Embedding。例如通过word2vec模型对用户行为建模,即可得到“商品”端的 Embedding;然后通过 avg(用户产生行为的商品的Embedding),即可得到“用户”端的 Embedding)


查看一下模型的summary信息。

model = build_model()print(model.summary())


所构造的DSSM模型结构如下所示。由于未对用户和商品的ID进行Embedding操作,所以该模型的参数较少。


打印一下模型的结构。

plot_model(model, to_file= finetune_dssm_model.png )


(3)加载数据

考虑到数据量较大,所以采用 generator 模式对数据进行处理,防止加载全部数据,撑爆内存。

def file_generator(input_path, batch_size = None): while True: with open(input_path, r ) as f: pin_vec_array, sku_vec_array, y_array = [], [], [] cnt = 0  for line in f: buf = line[:-1].split( , ) pin_vec = np.array(buf[1:129], dtype=np.float32) sku_vec = np.array(buf[129:], dtype=np.float32) y = int(buf[0]) pin_vec_array.append(pin_vec) sku_vec_array.append(sku_vec) y_array.append(y) cnt += 1 if cnt % batch_size == 0: pin_vec_array = np.array(pin_vec_array) sku_vec_array = np.array(sku_vec_array) y_array = np.array(y_array) yield [pin_vec_array, sku_vec_array], y_array cnt = 0 pin_vec_array, sku_vec_array, y_array = [], [], []


本文使用小数据量进行试验,数据格式如下:

1,0.111400,0.298000,0.520000,-2.107100,-0.658500,-0.060500,-0.755700,-0.317100,0.786800,-0.051100,-0.514300,-0.772700,0.947900,0.045500,-0.146600,0.670900,0.739700,0.715800,0.519000,1.733300,-0.567100,0.475800,0.392100,0.386000,0.038900,-0.267600,-0.597700,0.365000,-1.514600,0.362100,-0.316900,0.873700,-0.208400,-0.079500,-0.401500,-0.040200,-0.545500,0.001900,0.018300,0.836700,-0.154500,-0.114000,0.648800,-0.949100,-0.074600,0.075200,0.846000,-0.234500,0.590100,-1.521400,0.374400,-0.194700,-0.309800,1.297600,0.329300,-1.250700,0.958500,-0.247100,0.083100,-1.150500,-0.535000,0.112800,-1.356800,0.879200,-0.353400,0.034500,0.241300,-0.205700,0.670600,0.633200,-0.368100,-0.754100,-0.153500,-0.475300,0.347100,0.370000,-0.380000,-0.739700,0.471700,-0.177900,0.308500,-0.058100,1.279900,0.776900,-0.088300,-1.248500,-0.973700,-0.211500,-0.210300,0.631500,-0.652400,0.866200,0.464500,-0.682000,-0.627600,-0.598000,-0.119200,0.473700,0.381500,0.567900,0.003600,-0.514900,0.536100,-0.803500,-0.619500,-0.141500,0.010400,1.268600,0.406200,-0.632000,0.250500,-0.218300,-0.168800,0.015000,-1.186700,-0.683500,1.632600,0.430000,-0.098000,0.436500,-0.068900,0.601700,0.006100,0.540800,-0.227800,-1.126100,1.165200,-0.220900,-0.202962,-3.636311,-0.504060,-2.546363,-1.235034,-0.883959,-0.348022,-0.219954,0.907031,-1.482731,0.669218,-0.477431,4.881980,3.885695,0.578319,1.427294,2.173270,-2.765083,0.004624,1.796896,1.087227,0.389897,0.604141,-1.155123,1.274209,-2.239976,-1.858146,3.090227,-0.206842,2.549677,2.601414,-0.692583,0.388238,-0.117103,-2.207036,-3.230492,-3.375904,-1.553133,2.262967,-2.091266,-0.825930,-2.791187,2.190521,-0.433236,-0.217687,-2.277860,-0.432154,-1.141102,-0.850199,-3.686642,2.615366,0.076896,-1.115686,1.734991,-1.578039,1.183485,0.641641,-2.347620,1.625458,-1.123846,1.017014,2.852135,-0.979481,0.912863,0.727238,-0.418464,-0.958715,-0.861919,0.282138,1.843323,0.175354,-1.792245,-1.370620,1.089480,0.778957,-2.377766,0.829453,-2.713742,-3.567303,-1.208078,1.233118,1.125459,4.193498,-2.459454,0.897581,1.001604,0.674028,-1.428830,-0.025545,1.150639,-3.673055,-0.666604,0.064266,0.285329,-1.370663,-0.463825,-0.842921,0.618591,1.990929,0.457696,-2.935576,0.301109,3.309814,-2.633363,-1.209220,-0.564443,-0.663638,1.399326,1.430363,-1.934421,-2.455737,-1.447479,0.263726,-0.861657,0.584651,-2.341039,3.445074,1.608032,0.724370,-0.370727,-2.025292,-0.842234,0.977376,3.447604,2.289111,2.478286,0.241298,-1.6748320,-0.804500,0.572300,-0.357900,0.472200,1.037200,0.266700,-0.023200,0.858800,-0.484500,-0.782800,0.480700,0.119000,-0.293300,-0.504600,0.374600,-0.039300,0.935600,-1.255600,-0.258700,-0.582000,-1.719200,0.307800,0.052900,0.381800,0.577100,-0.998900,0.060600,0.373900,-0.281600,0.024100,-0.332200,0.038900,0.136100,-0.002500,0.724800,0.038700,-0.148800,1.535200,-0.059800,0.322100,-0.811600,0.363400,-1.402800,0.158200,-0.507700,-0.108200,-0.051600,-0.286800,-0.345700,-0.152300,-0.201400,-0.494600,-0.716300,0.541900,-1.629700,-0.287000,-1.277400,1.244700,0.011400,0.549900,0.883000,-1.100400,-0.700300,-0.079900,-1.227600,0.047900,-0.769000,0.821900,0.783400,0.173500,0.697400,0.499200,0.602800,0.548200,-0.256100,-0.751800,1.143400,0.295100,-0.123700,-0.503200,-0.160300,-0.908800,-0.056600,0.107600,0.436000,0.679800,0.313100,-0.249200,0.779700,0.801200,-1.650800,0.089900,0.026200,-0.338600,-0.115900,0.495700,0.088600,0.526900,0.595000,0.156700,-0.736900,0.558100,-0.095900,0.072100,-0.209400,-0.999600,-0.567300,-0.017400,-0.232500,-0.538800,-0.041200,1.247400,-0.610300,0.085700,0.321900,0.478900,-0.274800,0.074000,-0.387400,-0.306000,0.204200,0.978300,-0.738800,0.267800,0.299300,0.989500,-0.597800,-0.211500,0.302525,0.926751,0.444355,2.095530,0.641599,0.585963,-0.007165,-0.225599,1.195284,0.743535,-0.283189,0.421811,-0.900632,-1.775821,0.194162,-0.131157,2.221316,-0.871263,0.611026,1.586028,0.208971,1.728807,-1.214678,-0.006417,-0.487578,-1.347446,1.257976,-1.105078,-0.641283,2.040870,-1.064334,1.848631,0.021456,1.044769,1.046561,-0.382474,0.511813,1.991464,1.541210,1.197348,-0.132546,-1.227524,-1.825696,0.637844,0.266854,0.627479,-1.939037,1.784560,-1.572687,1.319858,-0.297955,-0.648528,1.552862,-0.390313,-1.862317,-1.434988,1.003443,2.372627,0.048504,-1.178071,0.345171,-0.493632,0.708266,0.439852,1.367206,0.587270,-1.676261,1.519096,2.178505,0.398875,-0.987587,-1.099164,2.224100,-0.032785,-1.974257,-2.476301,1.279583,0.368386,0.118637,-0.390930,0.206159,-1.526931,-0.706359,-0.666684,1.660718,2.577286,2.185187,-0.082288,1.171966,-0.962591,-1.345657,3.024471,0.326179,-1.740565,0.338833,2.163889,-1.306316,0.962814,2.811996,0.795088,0.042636,-1.563679,0.169866,-0.691936,0.281116,-0.114342,-0.654810,-0.018624,-1.712857,-1.027673,0.120613,1.324406,-0.825408,0.978356,-0.286835,1.155605,-0.480432,-0.661304,0.434739,0.736817,-1.921379,1.111957,0.592577,-0.935139,-0.926583,2.585314,-0.798262,-0.515275


解释:第一个数据为label,1表示正样本,0表示负样本;第2列到第129列表示用户的Embedding数据;第130列到第257列表示商品的Embedding数据;


3. 训练DSSM模型


接下来开始训练DSSM模型。

def train_finetune_dssm(train_path, val_path, model_path, \\ n_train = None, \\ n_val = None): model = build_model() print( train samples numbers: %s % n_train) print( val samples numbers: %s % n_val) batch_size = 128 epochs = 2 train_steps_per_epoch = int(n_train / batch_size) val_steps_per_epoch = int(n_val / batch_size) train_generator = file_generator(train_path, batch_size = batch_size) val_generator = file_generator(val_path, batch_size = batch_size) early_stopping_cb = EarlyStopping(monitor = val_loss , patience = 10, restore_best_weights = True)  tensorboard_cb = TensorBoard(\\ log_dir = ./logs , \\ histogram_freq = 0, \\ write_graph = True, \\ write_grads = True, \\ write_images = True)

本文链接: http://npfine.immuno-online.com/view-728492.html

发布于 : 2021-03-25 阅读(0)