Memoisasi di Haskell?

139

Ada petunjuk tentang cara menyelesaikan secara efisien fungsi berikut di Haskell, untuk bilangan besar (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

Saya telah melihat contoh memoisasi di Haskell untuk menyelesaikan bilangan fibonacci, yang melibatkan komputasi (malas) semua bilangan fibonacci hingga n yang diperlukan. Tetapi dalam kasus ini, untuk n tertentu, kita hanya perlu menghitung sangat sedikit hasil antara.

Terima kasih

Angel de Vicente
sumber
110
Hanya dalam arti bahwa itu adalah beberapa pekerjaan yang saya lakukan di rumah :-)
Angel de Vicente

Jawaban:

258

Kita dapat melakukannya dengan sangat efisien dengan membuat struktur yang dapat kita indeks dalam waktu sub-linier.

Tapi pertama-tama,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Mari kita definisikan f, tetapi gunakan 'rekursi terbuka' daripada memanggil dirinya sendiri secara langsung.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

Anda bisa mendapatkan keliru fdengan menggunakanfix f

Ini akan memungkinkan Anda menguji fapakah yang Anda maksud untuk nilai-nilai kecil fdengan memanggil, misalnya:fix f 123 = 144

Kita bisa mengingat ini dengan mendefinisikan:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

Itu berkinerja cukup baik, dan menggantikan apa yang akan memakan waktu O (n ^ 3) dengan sesuatu yang mengingat hasil antara.

Tetapi masih membutuhkan waktu linier hanya untuk mengindeks untuk menemukan jawaban yang dimo mf. Artinya hasil seperti:

*Main Data.List> faster_f 123801
248604

dapat ditoleransi, tetapi hasilnya tidak berskala lebih baik dari itu. Kami bisa lebih baik!

Pertama, mari kita definisikan pohon tak terbatas:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

Dan kemudian kita akan menentukan cara untuk mengindeks ke dalamnya, sehingga kita dapat menemukan simpul dengan indeks ndalam waktu O (log n) :

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... dan kami mungkin merasa nyaman menggunakan pohon yang penuh dengan bilangan asli, jadi kami tidak perlu bermain-main dengan indeks tersebut:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Karena kami dapat mengindeks, Anda dapat mengubah pohon menjadi daftar:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

Anda dapat memeriksa pekerjaan sejauh ini dengan memverifikasi yang toList natsmemberi Anda[0..]

Sekarang,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

bekerja seperti daftar di atas, tetapi alih-alih mengambil waktu linier untuk menemukan setiap node, dapat mengejarnya dalam waktu logaritmik.

Hasilnya jauh lebih cepat:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

Faktanya, ini jauh lebih cepat sehingga Anda dapat melalui dan mengganti Intdengan di Integeratas dan mendapatkan jawaban yang sangat besar hampir secara instan

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
Edward KMETT
sumber
3
Saya mencoba kode ini dan, yang menarik, f_faster sepertinya lebih lambat dari f. Saya kira referensi daftar itu benar-benar memperlambat segalanya. Definisi nats dan indeks tampak cukup misterius bagi saya, jadi saya telah menambahkan jawaban saya sendiri yang mungkin dapat memperjelas semuanya.
Pitarou
5
Kasus daftar tak terbatas harus berurusan dengan daftar tertaut 111111111 item. Kasus pohon berurusan dengan log n * jumlah node tercapai.
Edward KMETT
2
yaitu versi daftar harus menghasilkan banyak node dalam daftar, sedangkan versi hierarki menghindari pembuatan banyak node.
Tom Ellis
7
Saya tahu ini adalah posting yang agak lama, tetapi tidak f_treeboleh didefinisikan dalam whereklausa untuk menghindari menyimpan jalur yang tidak diperlukan di pohon di seluruh panggilan?
dfeuer
17
Alasan untuk memasukkannya ke dalam CAF adalah Anda bisa mendapatkan memoisasi di seluruh panggilan. Jika saya memiliki panggilan mahal yang saya ingat, maka saya mungkin akan meninggalkannya di CAF, karena itu teknik yang ditunjukkan di sini. Dalam penerapan nyata tentu saja ada trade-off antara manfaat dan biaya memoisasi permanen. Meskipun, mengingat pertanyaannya adalah tentang bagaimana mencapai memoization, saya pikir itu akan menyesatkan untuk menjawab dengan teknik yang sengaja menghindari memoization di seluruh panggilan, dan jika tidak ada yang lain maka komentar ini di sini akan menunjukkan kepada orang-orang fakta bahwa ada seluk-beluk. ;)
Edward KMETT
17

Jawaban Edward adalah permata yang luar biasa sehingga saya telah menduplikasinya dan memberikan implementasi memoListdan memoTreekombinator yang mengenang suatu fungsi dalam bentuk rekursif terbuka.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
Tom Ellis
sumber
12

Bukan cara yang paling efisien, tetapi tidak memo:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

saat meminta f !! 144, dicentang bahwa f !! 143ada, tetapi nilai pastinya tidak dihitung. Ini masih ditetapkan sebagai hasil perhitungan yang tidak diketahui. Satu-satunya nilai tepat yang dihitung adalah yang dibutuhkan.

Jadi awalnya, sejauh berapa yang sudah dihitung, program itu tidak tahu apa-apa.

f = .... 

Saat kami membuat permintaan f !! 12, itu mulai melakukan beberapa pencocokan pola:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Sekarang mulai menghitung

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

Ini secara rekursif membuat permintaan lain pada f, jadi kami menghitung

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Sekarang kita bisa mendapatkan kembali beberapa

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

Artinya, program tersebut sekarang mengetahui:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Terus menetes:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

Artinya, program tersebut sekarang mengetahui:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Sekarang kita lanjutkan dengan perhitungan kita tentang f!!6:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

Artinya, program tersebut sekarang mengetahui:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Sekarang kita lanjutkan dengan perhitungan kita tentang f!!12:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

Artinya, program tersebut sekarang mengetahui:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

Jadi perhitungannya dilakukan dengan cukup malas. Program mengetahui bahwa beberapa nilai f !! 8ada, bahwa itu sama dengan g 8, tetapi tidak tahu apa g 8itu.

rampion
sumber
Terima kasih untuk yang ini. Bagaimana Anda membuat dan menggunakan ruang solusi 2 dimensi? Apakah itu daftar daftarnya? dang n m = (something with) f!!a!!b
vikingsteve
1
Tentu, kamu bisa. Untuk solusi nyata, saya mungkin akan menggunakan perpustakaan memoization
rampion
Sayangnya itu O (n ^ 2).
Qumeric
9

Seperti yang dinyatakan dalam jawaban Edward Kmett, untuk mempercepat, Anda perlu menyimpan cache penghitungan yang mahal dan dapat mengaksesnya dengan cepat.

Untuk menjaga agar fungsinya tetap non monad, solusi membangun pohon malas tak terbatas, dengan cara yang tepat untuk mengindeksnya (seperti yang ditunjukkan di posting sebelumnya) memenuhi tujuan itu. Jika Anda melepaskan sifat non-monad dari fungsi tersebut, Anda dapat menggunakan container asosiatif standar yang tersedia di Haskell dalam kombinasi dengan monad “mirip-status” (seperti State atau ST).

Meskipun kelemahan utamanya adalah Anda mendapatkan fungsi non-monad, Anda tidak perlu mengindeks struktur itu sendiri lagi, dan cukup menggunakan implementasi standar dari wadah asosiatif.

Untuk melakukannya, pertama-tama Anda perlu menulis ulang fungsi Anda untuk menerima jenis monad:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

Untuk pengujian Anda, Anda masih bisa menentukan fungsi yang tidak melakukan memoisasi menggunakan Data.Function.fix, meskipun sedikit lebih panjang:

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

Anda kemudian dapat menggunakan monad Status dalam kombinasi dengan Data.Map untuk mempercepat:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

Dengan perubahan kecil, Anda dapat menyesuaikan kode agar berfungsi dengan Data.HashMap sebagai gantinya:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

Alih-alih struktur data persisten, Anda juga dapat mencoba struktur data yang dapat berubah (seperti Data.HashTable) dalam kombinasi dengan monad ST:

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

Dibandingkan dengan implementasi tanpa memoisasi, salah satu implementasi ini memungkinkan Anda, untuk masukan yang sangat besar, mendapatkan hasil dalam hitungan mikro-detik daripada harus menunggu beberapa detik.

Menggunakan Criterion sebagai patokan, saya dapat mengamati bahwa implementasi dengan Data.HashMap sebenarnya berkinerja sedikit lebih baik (sekitar 20%) daripada Data.Map dan Data.HashTable yang pengaturan waktunya sangat mirip.

Saya menemukan hasil benchmark agak mengejutkan. Perasaan awal saya adalah bahwa HashTable akan mengungguli implementasi HashMap karena dapat berubah. Mungkin ada beberapa cacat kinerja yang tersembunyi dalam penerapan terakhir ini.

Quentin
sumber
2
GHC melakukan pekerjaan yang sangat baik dalam mengoptimalkan di sekitar struktur yang tidak dapat diubah. Intuisi dari C tidak selalu berhasil.
John Tyree
8

Ini adalah tambahan dari jawaban luar biasa Edward Kmett.

Ketika saya mencoba kodenya, definisi dari natsdan indextampaknya cukup misterius, jadi saya menulis versi alternatif yang menurut saya lebih mudah untuk dipahami.

Saya mendefinisikan indexdan natsdalam istilah index'dan nats'.

index' t nditentukan selama rentang [1..]. (Ingatlah yang index tdidefinisikan selama rentang [0..].) Ia bekerja mencari pohon dengan memperlakukan nsebagai string bit, dan membaca bit secara terbalik. Jika bitnya 1, dibutuhkan cabang sebelah kanan. Jika bit 0, dibutuhkan cabang kiri. Ini berhenti ketika mencapai bit terakhir (yang harus a 1).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Sama seperti natsyang didefinisikan indexsehingga index nats n == nselalu benar, nats'didefinisikan untuk index'.

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

Sekarang, natsdan indexsecara sederhana nats'dan index'tetapi dengan nilai-nilai digeser oleh 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
Pitarou
sumber
Terima kasih. Saya mengingat fungsi multivariasi, dan ini benar-benar membantu saya mengetahui apa yang sebenarnya dilakukan index dan nats.
Kittsil
3

Beberapa tahun kemudian, saya melihat ini dan menyadari ada cara sederhana untuk membuat memo ini dalam waktu linier menggunakan zipWithdan fungsi pembantu:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilatememiliki properti berguna itu dilate n xs !! i == xs !! div i n.

Jadi, misalkan kita diberi f (0), ini menyederhanakan komputasi menjadi

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

Sangat mirip dengan deskripsi masalah awal kita, dan memberikan solusi linier ( sum $ take n fsakan mengambil O (n)).

rampion
sumber
2
jadi ini solusi generatif (korekursif?), atau pemrograman dinamis. Mengambil O (1) kali untuk setiap nilai yang dihasilkan, seperti yang biasa dilakukan Fibonacci. Bagus! Dan solusi EKMETT seperti logaritmik-Fibonacci besar, mencapai angka-angka besar jauh lebih cepat, melewati sebagian besar persilangan. Apakah ini benar?
Will Ness
atau mungkin lebih dekat dengan angka Hamming, dengan tiga penunjuk ke belakang ke dalam urutan yang sedang diproduksi, dan kecepatan yang berbeda untuk masing-masing maju di sepanjang itu. benar-benar cantik.
Will Ness
2

Namun tambahan lain untuk jawaban Edward Kmett: contoh mandiri:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Gunakan sebagai berikut untuk membuat memo fungsi dengan arg integer tunggal (misalnya fibonacci):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Hanya nilai untuk argumen non-negatif yang akan di-cache.

Untuk juga menyimpan nilai untuk argumen negatif, gunakan memoInt, didefinisikan sebagai berikut:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

Untuk menyimpan nilai-nilai untuk fungsi dengan dua argumen integer digunakan memoIntInt, didefinisikan sebagai berikut:

memoIntInt f = memoInt (\n -> memoInt (f n))
Neal Young
sumber
2

Solusi tanpa pengindeksan, dan tidak berdasarkan pada Edward KMETT.

Saya memfaktorkan keluar subpohon umum untuk orang tua yang sama ( f(n/4)dibagi antara f(n/2)dan f(n/4), dan f(n/6)dibagi antara f(2)dan f(3)). Dengan menyimpannya sebagai variabel tunggal di induk, penghitungan subpohon dilakukan satu kali.

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

Kode tidak dengan mudah diperluas ke fungsi memoization umum (setidaknya, saya tidak tahu bagaimana melakukannya), dan Anda benar-benar harus memikirkan bagaimana subproblem tumpang tindih, tetapi strategi harus bekerja untuk beberapa parameter non-integer umum . (Saya memikirkannya untuk dua parameter string.)

Memo tersebut akan dibuang setelah setiap perhitungan. (Sekali lagi, saya memikirkan tentang dua parameter string.)

Saya tidak tahu apakah ini lebih efisien daripada jawaban lainnya. Setiap pencarian secara teknis hanya satu atau dua langkah ("Lihatlah anak Anda atau anak Anda"), tetapi mungkin ada banyak penggunaan memori tambahan.

Sunting: Solusi ini belum benar. Pembagiannya tidak lengkap.

Sunting: Seharusnya berbagi anak dengan benar sekarang, tetapi saya menyadari bahwa masalah ini memiliki banyak berbagi nontrivial: n/2/2/2dan n/3/3mungkin sama. Masalahnya tidak cocok untuk strategi saya.

leewz
sumber