五個簡單的範例 在 GitHub 上編輯

以下是開始使用 Torch 的五個入門步驟!本教學課程假設您已透過

require 'torch'

下載套件 torch

或您使用 REPL th (它會自動下載)。

1. 定義一個正定的二次形式

  • 我們在此使用一些 torch 函數
  • rand(),可建立採用均勻分布繪製的張量
  • t(),可對張量進行轉置 (請注意,它會傳回一個新的檢視)
  • dot(),可在兩個張量之間執行點積
  • eye(),可傳回單位矩陣

* 對矩陣進行的運算 (可執行矩陣-向量或矩陣-矩陣的乘法)

torch.manualSeed(1234)
-- choose a dimension
N = 5

-- create a random NxN matrix
A = torch.rand(N, N)

-- make it symmetric positive
A = A*A:t()

-- make it definite
A:add(0.001, torch.eye(N))

-- add a linear term
b = torch.rand(N)

-- create the quadratic form
function J(x)
   return 0.5*x:dot(A*x)-b:dot(x)
end

我們的步驟是先確認亂數種子對所有人都相同

print(J(torch.rand(N)))

列印函數值 (在此以亂數點為例) 的方式非常簡單

2. 找出確切的最小值

xs = torch.inverse(A)*b
print(string.format('J(x^*) = %g', J(xs)))

我們可以求矩陣的逆矩陣 (可能在數值上並非最佳做法)

3. 使用梯度下降法搜尋最小值

function dJ(x)
  return A*x-b
end

我們首先定義 J(x)x 的梯度

x = torch.rand(N)

接著定義一些當前的解答

lr = 0.01
for i=1,20000 do
  x = x - dJ(x)*lr
  -- we print the value of the objective function at each iteration
  print(string.format('at iter %d J(x) = %f', i, J(x)))
end

然後進行梯度下降法 (搭配既有的學習率 lr) 一段時間

...
at iter 19995 J(x) = -3.135664
at iter 19996 J(x) = -3.135664
at iter 19997 J(x) = -3.135665
at iter 19998 J(x) = -3.135665
at iter 19999 J(x) = -3.135665
at iter 20000 J(x) = -3.135666

您應該會看到

4. 使用 optim 套件

luarocks install optim

想要使用更進階的最佳化技術,例如共軛梯度法或 LBFGS?optim 套件就能滿足您的需求!首先,我們需要安裝它

local 變數

local A = torch.rand(N, N)

實際上,使用全域變數絕非好主意。請在每個位置使用 local。在我們的範例中,我們將所有內容都定義為全域,方便在直譯器指令行中剪貼和貼上。若定義 local 例如

將只能用於當前的範圍,也就是在執行直譯器時,僅限於目前的輸入行。後續的幾行無法使用這個 local

do
   local A = torch.rand(N, N)
   print(A)
end
print(A)

在 lua 中,可以使用 do...end 指令定義範圍

使用 upvalue 定義閉包

我們需要定義一個同時傳回 J(x)dJ(x) 的閉包。在此我們透過 do...end 定義範圍,如此才能讓 local 變數 neval 成為 JdJ(x) 的 upvalue:只有 JdJ(x) 能讀取該變數。請注意,在腳本中並不需要有 do...end 範圍,因為 neval 的範圍會持續到腳本檔案結束為止(而不是像命令提示字元那樣到該列結束)。

do
   local neval = 0
   function JdJ(x)
      local Jx = J(x)
      neval = neval + 1
      print(string.format('after %d evaluations J(x) = %f', neval, Jx))
      return Jx, dJ(x)
   end
end

透過 optim 訓練

套件並未預設載入,我們需要 require

require 'optim'

我們首先為共軛梯度定義一個狀態

state = {
   verbose = true,
   maxIter = 100
}

現在開始訓練

x = torch.rand(N)
optim.cg(JdJ, x, state)

你應該會看到類似以下的內容

after 120 evaluation J(x) = -3.136835
after 121 evaluation J(x) = -3.136836
after 122 evaluation J(x) = -3.136837
after 123 evaluation J(x) = -3.136838
after 124 evaluation J(x) = -3.136840
after 125 evaluation J(x) = -3.136838

5. 繪製

繪製圖形的方式有很多種。舉例來說,一個人可以使用最近發表的 iTorch 套件。在此,我們將使用 gnuplot

luarocks install gnuplot

儲存中間函數的評估結果

我們稍稍修改先前那個閉包,以便將中間函數的評估結果(以及到目前為止的實際訓練時間)儲存起來

evaluations = {}
time = {}
timer = torch.Timer()
neval = 0
function JdJ(x)
   local Jx = J(x)
   neval = neval + 1
   print(string.format('after %d evaluations, J(x) = %f', neval, Jx))
   table.insert(evaluations, Jx)
   table.insert(time, timer:time().real)
   return Jx, dJ(x)
end

現在我們可以開始訓練

state = {
   verbose = true,
   maxIter = 100
}

x0 = torch.rand(N)
cgx = x0:clone() -- make a copy of x0
timer:reset()
optim.cg(JdJ, cgx, state)

-- we convert the evaluations and time tables to tensors for plotting:
cgtime = torch.Tensor(time)
cgevaluations = torch.Tensor(evaluations)

加入 Stochastic Gradient Descent 支援

使用 optim 來加入隨機梯度訓練

evaluations = {}
time = {}
neval = 0
state = {
  lr = 0.1
}

-- we start from the same starting point than for CG
x = x0:clone()

-- reset the timer!
timer:reset()

-- note that SGD optimizer requires us to do the loop
for i=1,1000 do
  optim.sgd(JdJ, x, state)
  table.insert(evaluations, Jx)
end
  
sgdtime = torch.Tensor(time)
sgdevaluations = torch.Tensor(evaluations)

最後的繪製結果

現在我們可以來繪製圖形了。第一種簡單的方法是使用 gnuplot.plot(x, y)。在此,我們會先加上 gnuplot.figure(),以確保圖形畫在不同的圖形上。

require 'gnuplot'
gnuplot.figure(1)
gnuplot.title('CG loss minimisation over time')
gnuplot.plot(cgtime, cgevaluations)

gnuplot.figure(2)
gnuplot.title('SGD loss minimisation over time')
gnuplot.plot(sgdtime, sgdevaluations)

更進階的方法是將所有內容繪製在同一個圖形上,如下所示。在此,我們將所有內容儲存在 PNG 檔案中。

gnuplot.pngfigure('plot.png')
gnuplot.plot(
   {'CG',  cgtime,  cgevaluations,  '-'},
   {'SGD', sgdtime, sgdevaluations, '-'})
gnuplot.xlabel('time (s)')
gnuplot.ylabel('J(x)')
gnuplot.plotflush()

CG vs SGD