2015-05-31 10 views
5

Ich brauche zu initialisieren ein 3D-Tensor mit einem Index abhängige Funktion in torch7 zu initialisieren, dhSchneller Weg, um einen Tensor in torch7

func = function(i,j,k) --i, j is the index of an element in the tensor 
    return i*j*k  --do operations within func which're dependent of i, j 
end 

dann initialisieren ich ein 3D-Tensor A wie folgt aus:

for i=1,A:size(1) do 
    for j=1,A:size(2) do 
     for k=1,A:size(3) do 
      A[{i,j,k}] = func(i,j,k) 
     end 
    end 
end 

Aber dieser Code läuft sehr langsam, und ich fand es dauert 92% der gesamten Laufzeit. Gibt es effizientere Möglichkeiten, einen 3D-Tensor in Fackel 7 zu initialisieren?

+0

Was die Größe von 'A'? – ryanpattison

Antwort

7

finden Sie in der Dokumentation für die Tensor:apply

Diese Funktionen gelten, eine Funktion zu jedem Element des Tensor auf , die das Verfahren (Selbst-) genannt wird. Diese Methoden sind viel schneller als mit einer for-Schleife in Lua.

Das Beispiel in der Dokumentation initialisiert ein 2D-Array basierend auf seinem Index i (im Speicher). Unten ist ein erweitertes Beispiel für 3 Dimensionen und darunter für N-D Tensoren. die Anwendung mit der Methode ist viel, viel schneller auf meiner Maschine:

require 'torch' 

A = torch.Tensor(100, 100, 1000) 
B = torch.Tensor(100, 100, 1000) 

function func(i,j,k) 
    return i*j*k  
end 

t = os.clock() 
for i=1,A:size(1) do 
    for j=1,A:size(2) do 
     for k=1,A:size(3) do 
      A[{i, j, k}] = i * j * k 
     end 
    end 
end 
print("Original time:", os.difftime(os.clock(), t)) 

t = os.clock() 
function forindices(A, func) 
    local i = 1 
    local j = 1 
    local k = 0 
    local d3 = A:size(3) 
    local d2 = A:size(2) 
    return function() 
    k = k + 1 
    if k > d3 then 
     k = 1 
     j = j + 1 
     if j > d2 then 
     j = 1 
     i = i + 1 
     end 
    end 
    return func(i, j, k) 
    end 
end 

B:apply(forindices(A, func)) 
print("Apply method:", os.difftime(os.clock(), t)) 

EDIT

Dieses ist für jede Tensor Objekt arbeiten:

function tabulate(A, f) 
    local idx = {} 
    local ndims = A:dim() 
    local dim = A:size() 
    idx[ndims] = 0 
    for i=1, (ndims - 1) do 
    idx[i] = 1 
    end 
    return A:apply(function() 
    for i=ndims, 0, -1 do 
     idx[i] = idx[i] + 1 
     if idx[i] <= dim[i] then 
     break 
     end 
     idx[i] = 1 
    end 
    return f(unpack(idx)) 
    end) 
end 

-- usage for 3D case. 
tabulate(A, function(i, j, k) return i * j * k end) 
+0

@delteil ja, danke. – ryanpattison

+0

Gern geschehen! (Kommentar entfernt, da es danach nicht mehr relevant ist [edit] (http://stackoverflow.com/revisions/30560653/5)) – deltheil

+0

große Antwort! Solange der Funktor richtig JIT-kompiliert werden kann, wird es sehr schnell sein (nahe C-Geschwindigkeiten) – smhx