濮阳杆衣贸易有限公司

主頁(yè) > 知識(shí)庫(kù) > keras的get_value運(yùn)行越來(lái)越慢的解決方案

keras的get_value運(yùn)行越來(lái)越慢的解決方案

熱門標(biāo)簽:公司電話機(jī)器人 陜西金融外呼系統(tǒng) 海南400電話如何申請(qǐng) 白銀外呼系統(tǒng) 哈爾濱ai外呼系統(tǒng)定制 激戰(zhàn)2地圖標(biāo)注 騰訊外呼線路 廣告地圖標(biāo)注app 唐山智能外呼系統(tǒng)一般多少錢

keras 深度學(xué)習(xí)框架中g(shù)et_value函數(shù)運(yùn)行越來(lái)越慢,內(nèi)存消耗越來(lái)越大問(wèn)題

問(wèn)題描述

如上圖所示,經(jīng)過(guò)時(shí)間和內(nèi)存消耗跟蹤測(cè)試,發(fā)現(xiàn)是keras.backend.get_value() 函數(shù)導(dǎo)致的程序越來(lái)越慢,而且嚴(yán)重的造成內(nèi)存泄露;

查看該函數(shù)內(nèi)部實(shí)現(xiàn),發(fā)現(xiàn)一個(gè)主要核心是x.eval(session=get_session()),該語(yǔ)句可能是導(dǎo)致內(nèi)存泄露和運(yùn)行慢的核心語(yǔ)句; 根據(jù)查看一些博文得到了運(yùn)行得越來(lái)越慢的

原因該x.eval函數(shù)會(huì)添加新的節(jié)點(diǎn)到tf的圖中;而這也導(dǎo)致了tf的圖越來(lái)越大,內(nèi)存泄露;

解決方法

import tensorflow.keras.backend as K

def get_my_session(gpu_fraction=0.1):
    '''Assume that you have 6GB of GPU memory and want to allocate ~2GB'''

    num_threads = os.environ.get('OMP_NUM_THREADS')
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)

    if num_threads:
        return tf.Session(config=tf.ConfigProto(
            gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))
    else:
        return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

K.set_session(get_my_session())

如上圖所示, 我在使用tensorflow之前(也就是該工程文件前面),對(duì)session進(jìn)行自定義,然后用自定義的session設(shè)定keras.backend.set_session();

然后刪除get_value() 函數(shù),直接用get_value()中所使用的執(zhí)行語(yǔ)句x.eval(session=get_my_session());這樣這個(gè)添加節(jié)點(diǎn)導(dǎo)致內(nèi)存泄露的核心語(yǔ)句x.eval()就使用的是該工程統(tǒng)一自定義session,然后用tf.reset_default_graph() 對(duì)圖重置就可以了

即上圖問(wèn)題代碼修改為:

output = ctc_decode(y_pred,input_length=input_length,)
output = output[0][0]
out = output.eval(session=get_my_session())
# 刪除 K.get_value(out[0][0])
tf.reset_default_graph() # 然后重置tf圖,這句很關(guān)鍵

這樣就解決了get_value()導(dǎo)致的越來(lái)越慢的問(wèn)題;

個(gè)人認(rèn)為:這樣可能就不會(huì)總是添加新的節(jié)點(diǎn),導(dǎo)致tf圖不斷地?zé)o限變大;而是重復(fù)使用這一個(gè)自定義的節(jié)點(diǎn)。

補(bǔ)充:tensorflow與keras之間版本問(wèn)題引起get_session問(wèn)題解決辦法

1.產(chǎn)生報(bào)錯(cuò)原因

import tensorflow.keras.backend as K
def __init__(self, **kwargs):
    self.__dict__.update(self._defaults) # set up default values
    self.__dict__.update(kwargs) # and update with user overrides
    self.class_names = self._get_class()
    self.anchors = self._get_anchors()
    self.sess = K.get_session()

報(bào)錯(cuò)如下:

get_session is not available when using TensorFlow 2.0.

意思是 tf2.0 沒(méi)有 get_session

2.解決方案1

import tensorflow.python.keras.backend as K
sess = K.get_session()

3. 解決方案2

import tensorflow as tf
sess = tf.compat.v1.keras.backend.get_session()

之前一直采用方案1 解決,感覺(jué)比較方便;但是解決方案1 有其它屬性會(huì)丟失問(wèn)題

比如AttributeError: module ‘keras.backend' has no attribute image_dim_ordering

所以建議大家采用方案2

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • keras修改backend的簡(jiǎn)單方法
  • 基于keras中訓(xùn)練數(shù)據(jù)的幾種方式對(duì)比(fit和fit_generator)
  • 淺談Keras中fit()和fit_generator()的區(qū)別及其參數(shù)的坑
  • Keras保存模型并載入模型繼續(xù)訓(xùn)練的實(shí)現(xiàn)
  • TensorFlow2.0使用keras訓(xùn)練模型的實(shí)現(xiàn)
  • tensorflow2.0教程之Keras快速入門
  • 淺析關(guān)于Keras的安裝(pycharm)和初步理解
  • 基于Keras的擴(kuò)展性使用

標(biāo)簽:惠州 黑龍江 上海 四川 益陽(yáng) 常德 黔西 鷹潭

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《keras的get_value運(yùn)行越來(lái)越慢的解決方案》,本文關(guān)鍵詞  keras,的,get,value,運(yùn)行,越來(lái),;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問(wèn)題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無(wú)關(guān)。
  • 相關(guān)文章
  • 下面列出與本文章《keras的get_value運(yùn)行越來(lái)越慢的解決方案》相關(guān)的同類信息!
  • 本頁(yè)收集關(guān)于keras的get_value運(yùn)行越來(lái)越慢的解決方案的相關(guān)信息資訊供網(wǎng)民參考!
  • 推薦文章
    格尔木市| 建德市| 青浦区| 梁平县| 海淀区| 大冶市| 隆德县| 余江县| 云阳县| 锦州市| 台北县| 灵宝市| 信宜市| 大同市| 拉萨市| 柘城县| 日喀则市| 丰县| 化德县| 寿光市| 深圳市| 阳曲县| 德钦县| 化州市| 阜康市| 光泽县| 都昌县| 东乡族自治县| 高邮市| 高邑县| 抚州市| 阜南县| 芦山县| 牡丹江市| 渭南市| 绵阳市| 利川市| 临猗县| 浮山县| 搜索| 武川县|