十年网站开发经验 + 多家企业客户 + 靠谱的建站团队
量身定制 + 运营维护+专业推广+无忧售后,网站问题一站解决
创新互联www.cdcxhl.cn八线动态BGP香港云服务器提供商,新人活动买多久送多久,划算不套路!
成都创新互联致力于互联网品牌建设与网络营销,包括成都网站建设、做网站、SEO优化、网络推广、整站优化营销策划推广、电子商务、移动互联网营销等。成都创新互联为不同类型的客户提供良好的互联网应用定制及解决方案,成都创新互联核心团队十载专注互联网开发,积累了丰富的网站经验,为广大企业客户提供一站式企业网站建设服务,在网站建设行业内树立了良好口碑。不懂Keras模型转成tensorflow中.pb的方法?其实想解决这个问题也不难,下面让小编带着大家一起学习怎么去解决,希望大家阅读完这篇文章后大所收获。
Keras的.h6模型转成tensorflow的.pb格式模型,方便后期的前端部署。直接上代码
from keras.models import Model from keras.layers import Dense, Dropout from keras.applications.mobilenet import MobileNet from keras.applications.mobilenet import preprocess_input from keras.preprocessing.image import load_img, img_to_array import tensorflow as tf from keras import backend as K import os base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None) x = Dropout(0.75)(base_model.output) x = Dense(10, activation='softmax')(x) model = Model(base_model.input, x) model.load_weights('mobilenet_weights.h6') def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): from tensorflow.python.framework.graph_util import convert_variables_to_constants graph = session.graph with graph.as_default(): freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) output_names = output_names or [] output_names += [v.op.name for v in tf.global_variables()] input_graph_def = graph.as_graph_def() if clear_devices: for node in input_graph_def.node: node.device = "" frozen_graph = convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_names) return frozen_graph output_graph_name = 'NIMA.pb' output_fld = '' #K.set_learning_phase(0) print('input is :', model.input.name) print ('output is:', model.output.name) sess = K.get_session() frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name]) from tensorflow.python.framework import graph_io graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False) print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))