2016-07-17 19 views
1

Ich versuche, einen parallelen Dataloader der torch-dataframe hinzuzufügen, um torchnet compatibility hinzuzufügen.Torch out of memory in Thread bei der Verwendung von torch.serialize zweimal

  1. Ein Grundansatz wird außerhalb der Fäden
  2. Der Ansatz wird serialisiert und auf den Faden
  3. im Gewinde der Stapel geladen wird deserialisiert und wandelt das: Ich habe die tnt.ParallelDatasetIteratorchanged it und so die verwendet Stapeldaten zu Tensoren
  4. Die Tensoren werden in einer Tabelle mit den Schlüsseln input und target zurückgegeben, um die tnt.Engine Einstellung zu entsprechen.

Das Problem tritt zum zweiten Mal der enque mit einem Fehler aufgerufen wird: .../torch_distro/install/bin/luajit: not enough memory. Ich arbeite derzeit nur mit mnist mit einem angepassten mnist-example. Die enque Schleife sieht nun wie folgt (mit Debug-Speicherausgabe):

-- `samplePlaceholder` stands in for samples which have been 
-- filtered out by the `filter` function 
local samplePlaceholder = {} 

-- The enque does the main loop 
local idx = 1 
local function enqueue() 
    while idx <= size and threads:acceptsjob() do 
    local batch, reset = self.dataset:get_batch(batch_size) 

    if (reset) then 
     idx = size + 1 
    else 
     idx = idx + 1 
    end 

    if (batch) then 
     local serialized_batch = torch.serialize(batch) 

     -- In the parallel section only the to_tensor is run in parallel 
     -- this should though be the computationally expensive operation 
     threads:addjob(
     function(argList) 
      io.stderr:write("\n Start"); 
      io.stderr:write("\n 1: " ..tostring(collectgarbage("count"))) 
      local origIdx, serialized_batch, samplePlaceholder = unpack(argList) 

      io.stderr:write("\n 2: " ..tostring(collectgarbage("count"))) 
      local batch = torch.deserialize(serialized_batch) 
      serialized_batch = nil 

      collectgarbage() 
      collectgarbage() 

      io.stderr:write("\n 3: " .. tostring(collectgarbage("count"))) 
      batch = transform(batch) 

      io.stderr:write("\n 4: " .. tostring(collectgarbage("count"))) 
      local sample = samplePlaceholder 
      if (filter(batch)) then 
      sample = {} 
      sample.input, sample.target = batch:to_tensor() 
      end 
      io.stderr:write("\n 5: " ..tostring(collectgarbage("count"))) 

      collectgarbage() 
      collectgarbage() 
      io.stderr:write("\n 6: " ..tostring(collectgarbage("count"))) 

      io.stderr:write("\n End \n"); 
      return { 
      sample, 
      origIdx 
      } 
     end, 
     function(argList) 
      sample, sampleOrigIdx = unpack(argList) 
     end, 
     {idx, serialized_batch, samplePlaceholder} 
    ) 
    end 
    end 
end 

ich collectgarbage habe bestreut und auch versucht, alle Objekte zu entfernen, nicht benötigt. Der Speicherausgang ist ziemlich einfach:

Start 
1: 374840.87695312 
2: 374840.94433594 
3: 372023.79101562 
4: 372023.85839844 
5: 372075.41308594 
6: 372023.73632812 
End 

Die Funktion, die die enque Schleifen ist die ungeordnete Funktion, die trivial ist (die Speicherfehler beim zweiten enque und die ausgelöst wird):

iterFunction = function() 
    while threads:hasjob() do 
    enqueue() 
    threads:dojob() 
    if threads:haserror() then 
     threads:synchronize() 
    end 
    enqueue() 

    if table.exact_length(sample) > 0 then 
     return sample 
    end 
    end 
end 

Antwort

1

Also das Problem war die torch.serialize, wo die Funktion in der Einrichtung den gesamten Datensatz an die Funktion gekoppelt. Beim Hinzufügen:

serialized_batch = nil 
collectgarbage() 
collectgarbage() 

Das Problem wurde behoben. Ich wollte außerdem wissen, was so viel Platz in Anspruch nahm, und der Täter stellte sich heraus, dass ich die Funktion in einer Umgebung definiert hatte, in der ein großer Datensatz mit der Funktion verflochten war und die Größe massiv anwuchs. Hier ist die ursprüngliche Definition des Datums lokalen

mnist = require 'mnist' 
local dataset = mnist[mode .. 'dataset']() 

-- PROBLEMATIC LINE BELOW -- 
local ext_resource = dataset.data:reshape(dataset.data:size(1), 
    dataset.data:size(2) * dataset.data:size(3)):double() 

-- Create a Dataframe with the label. The actual images will be loaded 
-- as an external resource 
local df = Dataframe(
    Df_Dict{ 
    label = dataset.label:totable(), 
    row_id = torch.range(1, dataset.data:size(1)):totable() 
    }) 

-- Since the mnist package already has taken care of the data 
-- splitting we create a single subsetter 
df:create_subsets{ 
    subsets = Df_Dict{core = 1}, 
    class_args = Df_Tbl({ 
    batch_args = Df_Tbl({ 
     label = Df_Array("label"), 
     data = function(row) 
     return ext_resource[row.row_id] 
     end 
    }) 
    }) 
} 

es stellt mich heraus, dass die Zeile entfernen, die ich markierte die Speichernutzung von 358 Mb reduziert bis 0,0008 Mb! Der Code, den ich für die Prüfung der Leistung verwendet wurde:

local mem = {} 
table.insert(mem, collectgarbage("count")) 

local ser_data = torch.serialize(batch.dataset) 
table.insert(mem, collectgarbage("count")) 

local ser_retriever = torch.serialize(batch.batchframe_defaults.data) 
table.insert(mem, collectgarbage("count")) 

local ser_raw_retriever = torch.serialize(function(row) 
    return ext_resource[row.row_id] 
end) 
table.insert(mem, collectgarbage("count")) 

local serialized_batch = torch.serialize(batch) 
table.insert(mem, collectgarbage("count")) 

for i=2,#mem do 
    print(i-1, (mem[i] - mem[i-1])/1024) 
end 

Welche produziert ursprünglich die Ausgabe:

1 0.0082607269287109 
2 358.23344707489 
3 0.0017471313476562 
4 358.90182781219 

und nach dem Update:

1 0.0094480514526367 
2 0.00080204010009766 
3 0.00090408325195312 
4 0.010146141052246 

Ich versuchte, die setfenv für die Verwendung von Funktion, aber das Problem wurde nicht gelöst. Es gibt immer noch eine Leistungseinbuße beim Senden der serialisierten Daten an den Thread, aber das Hauptproblem ist gelöst, und ohne den teuren Datenabruf ist die Funktion wesentlich kleiner.