[問題] 利用神經網路趨近0或1來判別(驗證問題)

看板Python作者 (宇宙射線)時間5年前 (2018/09/14 01:03), 5年前編輯推噓7(7021)
留言28則, 7人參與, 5年前最新討論串1/1
版上大神好,小弟大三EE還是此領域新手,目前使用tensorflow來實作 求大神解答QQ,本人寒酸有給建議的人一率奉上微薄100P 我稍微說明一下程式目的及架構,裡面code也有稍微打個註解 我有找到一份資料是是關於航線是否延遲共有8個特徵 裡面有九個元素前八個是特徵,最後一項是答案0或1 然後一共有1175筆資料 現在希望架設個神經網路訓練它,接著再拿別筆資料1*8矩陣送進去使它趨近0或1 架構應該是對的但想餵資料驗證時有問題QQ 當然裡面有先對資料做正規化,使用公式為Xnorm=(X-Xmin)/(Xmax-Xmin) 架構為inputs layer有8個神經元接著隱藏層1有6個神經元, 隱藏層2有3個神經元再輸出層一個神經元,也就是讓它1*8矩陣收斂為1*1矩陣 前面50筆當驗證用剩餘部分拿來訓練 大致就是隨機取一筆訓練30000次 這是我的程式檔: https://github.com/Hawkingfans/airline_delayed_judgment/blob/master/.gitignore 以及data: https://goo.gl/RNRxnD 放在github裡 但執行結果卻是: https://imgur.com/ZUwSqh8
每個output都一樣....猜是訓練方式有問題? 不知道可以如何修改? 但因發現似乎解法是想驗證一筆資料就要再重新train一次?不確定 所以暫時修改為:(只動with裡面) with tf.Session() as sess: sess.run(init) for k in range(0,10): #訓練部分(扣除前10筆資料所剩餘的) for i in range(20000): n = np.random.randint(10, 1175 ) line_data = norm_data[np.newaxis, n, 0: 8] expect = data[n, 8] expected= np.reshape(expect,[1,1]) sess.run(train_step, feed_dict={xs: line_data,ys: expected}) #驗證部分(前10筆資料) examine_data = norm_data[np.newaxis, k, 0: 8] answer = norm_data[k, 8] output = sess.run( y,feed_dict={xs: examine_data}) print("output =",output) print("answer=", answer) print("======================") 其實就是變成雙層for迴圈而已但因為這樣驗證會太久...所以就切10筆而已 結果如圖: https://imgur.com/jDSPjWk
看起來似乎有達到目的但這種做法要是切300筆資料要驗證, 不就變300*20000次....這會跑超久捏 = =" 且主要是這方法我拿去問教授,他說這樣是不對的......... 他說應該是訓練完好,驗證再另個block也就是不用重新train。 想請教各位如何修改才是正確的訓練? 希望以原本的方法為主,已經為此想破頭好幾天QAQ 感激不盡 -- ※ 發信站: 批踢踢實業坊(ptt.cc), 來自: 220.135.42.38 ※ 文章網址: https://www.ptt.cc/bbs/Python/M.1536858230.A.7FF.html

09/14 01:19, 5年前 , 1F
這樣看起來是分類問題,可以將label做one-hoe encodi
09/14 01:19, 1F

09/14 01:19, 5年前 , 2F
ng ,變成[0, 1], [1,0] ,output num = 2
09/14 01:19, 2F
抱歉不太懂您的意思QQ,剛剛搜了一下one-hot encoding好像是對字元做處理? 不過txt檔讀進去只是個1175*9的矩陣

09/14 02:02, 5年前 , 3F
嗯,這問題很大,第一,data太少,DL效果不好,可以試
09/14 02:02, 3F

09/14 02:02, 5年前 , 4F
試ML。第二,balance的問題。第三,要做 feature engin
09/14 02:02, 4F

09/14 02:02, 5年前 , 5F
eering
09/14 02:02, 5F

09/14 02:03, 5年前 , 6F
tensorflow 網路上很多範例,去看看應該有幫助,新手
09/14 02:03, 6F

09/14 02:04, 5年前 , 7F
直接碰tensorflow 有點困難,建議重基礎的ML開始,有
09/14 02:04, 7F

09/14 02:04, 5年前 , 8F
很多概念要知道
09/14 02:04, 8F
其實我不確定除了對train_step 做訓練還有無其他作法,因為教授要我們 直接以tensorflow開始QQ所以現在很像盲人摸象...之前唯一先接觸入門 就莫凡前半部的影片,不過謝謝你給的觀念我會研究一下ML

09/14 02:10, 5年前 , 9F
你沒用sigmoid
09/14 02:10, 9F

09/14 02:11, 5年前 , 10F
一般分類問題loss是用cross entropy
09/14 02:11, 10F
原本有試過sigmoid不過發現relu好像也能收斂所以就沒用惹, 雖然目前沒很了解cross entropy跟方均根的差別,但之前也有換成過cross...QQ

09/14 10:22, 5年前 , 11F
tensorflow 推 Hvass-Labs 的教學
09/14 10:22, 11F
謝謝推薦晚上回去來看看 以上都先發錢摟~ ※ 編輯: cosmicray (218.161.49.97), 09/14/2018 12:07:00

09/14 12:53, 5年前 , 12F
還沒仔細看程式,有幾個問題可能有幫助
09/14 12:53, 12F

09/14 12:54, 5年前 , 13F
1.神經元數量太少,網路根本沒辦法fit
09/14 12:54, 13F

09/14 12:57, 5年前 , 14F
2. 取資料訓練的時候應該是一次取一個batch下去訓練,
09/14 12:57, 14F

09/14 12:57, 5年前 , 15F
然後輪完整個資料集再重複動作
09/14 12:57, 15F

09/14 13:01, 5年前 , 16F
3. 我不太懂驗證的時候那個20000是做什麼的?驗證只要
09/14 13:01, 16F

09/14 13:01, 5年前 , 17F
輸入x跑一次就能得到一個預測的y了
09/14 13:01, 17F

09/14 13:03, 5年前 , 18F
另外如果你們教授許可的話,新手從keras或pytorch入
09/14 13:03, 18F

09/14 13:03, 5年前 , 19F
門會簡單非常多
09/14 13:03, 19F
其實那是訓練20000次的意思,但因為老師說我們嘗試統計一下正確率。 因此想說寫一個for統計一下答案跟出來的y正確率。 但發現做法1也就是github裡code發現出來的y都是一樣,因此無法統計。 但改成做法2也就是我修改後的用10包住20000等於每驗證一次就要重新train 不確定這想法是否對的?如果對的這樣未來驗證data變多就會跑很久QQ ※ 編輯: cosmicray (140.138.180.144), 09/14/2018 14:26:32

09/14 15:01, 5年前 , 20F
歡迎試著到DataScience板發文唷,那邊是機械學習專板
09/14 15:01, 20F
謝謝因為原本怕那邊會有語法不同問題 ※ 編輯: cosmicray (140.138.180.144), 09/14/2018 16:11:21

09/14 18:59, 5年前 , 21F
你可能要先搞懂一下batch這個概念,在神經網路中
09/14 18:59, 21F

09/14 19:00, 5年前 , 22F
一次train一筆跟一次train一個batch結果可能差很多
09/14 19:00, 22F

09/14 19:00, 5年前 , 23F
再來就是上面提到的神經網路太淺太瘦
09/14 19:00, 23F

09/14 19:01, 5年前 , 24F
試試看先把神經元數量加大 可能幾百以上
09/14 19:01, 24F

09/14 19:02, 5年前 , 25F
還有一點,你train完20000次之後,接著
09/14 19:02, 25F

09/14 19:03, 5年前 , 26F
驗證50筆資料,而不是你文中的每train20000再
09/14 19:03, 26F

09/14 19:03, 5年前 , 27F
驗證一筆
09/14 19:03, 27F
謝謝給的建議,會補齊相關觀念 的確都只試過trian一筆

09/15 05:39, 5年前 , 28F
何不用 keras?
09/15 05:39, 28F
嘛...老師要求從這入門,很盲人摸象QQ ※ 編輯: cosmicray (114.137.96.117), 09/15/2018 22:11:41
文章代碼(AID): #1RcfXsV_ (Python)