Ubah kolom DataFrame spark ke daftar python

104

Saya mengerjakan kerangka data dengan dua kolom, mvv dan hitungan.

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |

saya ingin mendapatkan dua daftar yang berisi nilai mvv dan nilai hitungan. Sesuatu seperti

mvv = [1,2,3,4]
count = [5,9,3,1]

Jadi, saya mencoba kode berikut: Baris pertama harus mengembalikan daftar baris python. Saya ingin melihat nilai pertama:

mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

Tetapi saya mendapatkan pesan kesalahan dengan baris kedua:

AttributeError: getInt

a.moussa
sumber
Pada Spark 2.3, kode ini adalah yang tercepat dan paling mungkin menyebabkan pengecualian OutOfMemory: list(df.select('mvv').toPandas()['mvv']). Panah diintegrasikan ke PySpark yang dipercepat toPandassecara signifikan. Jangan gunakan pendekatan lain jika Anda menggunakan Spark 2.3+. Lihat jawaban saya untuk detail pembandingan lebih lanjut.
Powers

Jawaban:

141

Lihat, mengapa cara yang Anda lakukan ini tidak berhasil. Pertama, Anda mencoba mendapatkan integer dari Row Type, output dari collect Anda adalah seperti ini:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

Jika Anda mengambil sesuatu seperti ini:

>>> firstvalue = mvv_list[0].mvv
Out: 1

Anda akan mendapatkan mvvnilainya. Jika Anda menginginkan semua informasi dari array Anda dapat mengambil sesuatu seperti ini:

>>> mvv_array = [int(row.mvv) for row in mvv_list.collect()]
>>> mvv_array
Out: [1,2,3,4]

Tetapi jika Anda mencoba hal yang sama untuk kolom lain, Anda mendapatkan:

>>> mvv_count = [int(row.count) for row in mvv_list.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

Ini terjadi karena countmetode bawaan. Dan kolom tersebut memiliki nama yang sama dengan count. Solusi untuk melakukan ini adalah mengubah nama kolom countmenjadi _count:

>>> mvv_list = mvv_list.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_list.collect()]

Tetapi solusi ini tidak diperlukan, karena Anda dapat mengakses kolom menggunakan sintaks kamus:

>>> mvv_array = [int(row['mvv']) for row in mvv_list.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_list.collect()]

Dan akhirnya akan berhasil!

Thiago Baldim
sumber
ini berfungsi dengan baik untuk kolom pertama, tetapi tidak berfungsi untuk jumlah kolom yang menurut saya karena (jumlah fungsi percikan)
a.moussa
Bisakah Anda menambahkan apa yang Anda lakukan dengan hitungan? Tambahkan di sini di komentar.
Thiago Baldim
terima kasih atas tanggapan Anda Jadi baris ini berfungsi mvv_list = [int (i.mvv) untuk saya di mvv_count.select ('mvv'). collect ()] tetapi bukan yang ini count_list = [int (i.count) untuk saya di mvv_count .select ('count'). collect ()] mengembalikan sintaks yang tidak valid
a.moussa
Tidak perlu menambahkan select('count')penggunaan ini seperti ini: count_list = [int(i.count) for i in mvv_list.collect()]Saya akan menambahkan contoh ke respons.
Thiago Baldim
1
@ a.moussa [i.['count'] for i in mvv_list.collect()]bekerja untuk membuatnya eksplisit menggunakan kolom bernama 'count' dan bukan countfungsinya
user989762
103

Mengikuti satu baris memberikan daftar yang Anda inginkan.

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()
Neo
sumber
3
Dari segi kinerja, solusi ini jauh lebih cepat daripada solusi Anda mvv_list = [int (i.mvv) untuk i di mvv_count.select ('mvv'). Collect ()]
Chanaka Fernando
Sejauh ini, ini adalah solusi terbaik yang pernah saya lihat. Terima kasih.
hui chen
22

Ini akan memberi Anda semua elemen sebagai daftar.

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)
Muhammad Raihan Muhaimin
sumber
1
Ini adalah solusi tercepat dan paling efisien untuk Spark 2.3+. Lihat hasil benchmarking di jawaban saya.
Powers
16

Kode berikut akan membantu Anda

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()
Itachi
sumber
3
Ini harus menjadi jawaban yang diterima. alasannya adalah Anda tetap berada dalam konteks percikan selama proses dan kemudian Anda mengumpulkan di akhir, bukan keluar dari konteks percikan lebih awal yang dapat menyebabkan pengumpulan yang lebih besar tergantung pada apa yang Anda lakukan.
AntiPawn79
15

Pada data saya, saya mendapatkan tolok ukur ini:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0,52 dtk

>>> [row[col] for row in data.collect()]

0,271 dtk

>>> list(data.select(col).toPandas()[col])

0,427 dtk

Hasilnya sama saja

luminousmen
sumber
1
Jika Anda menggunakan toLocalIteratoralih-alih collectitu seharusnya lebih hemat memori[row[col] for row in data.toLocalIterator()]
lihat
6

Jika Anda mendapatkan error di bawah ini:

AttributeError: objek 'list' tidak memiliki atribut 'kumpulkan'

Kode ini akan menyelesaikan masalah Anda:

mvv_list = mvv_count_df.select('mvv').collect()

mvv_array = [int(i.mvv) for i in mvv_list]
anirban sen
sumber
Saya mendapatkan kesalahan itu juga dan solusi ini menyelesaikan masalah. Tetapi mengapa saya mendapatkan kesalahan? (Banyak orang lain sepertinya tidak mengerti!)
bikashg
3

Saya menjalankan analisis benchmarking dan list(mvv_count_df.select('mvv').toPandas()['mvv'])merupakan metode tercepat. Saya sangat terkejut.

Saya menjalankan pendekatan yang berbeda pada 100 ribu / 100 juta kumpulan data baris menggunakan cluster 5 node i3.xlarge (setiap node memiliki 30,5 GB RAM dan 4 core) dengan Spark 2.4.5. Data didistribusikan secara merata pada 20 file Parket terkompresi tajam dengan satu kolom.

Berikut hasil benchmarking (runtime dalam detik):

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()] |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

Aturan emas yang harus diikuti saat mengumpulkan data di node driver:

  • Cobalah untuk menyelesaikan masalah dengan pendekatan lain. Mengumpulkan data ke node driver itu mahal, tidak memanfaatkan kekuatan cluster Spark, dan harus dihindari jika memungkinkan.
  • Kumpulkan baris sesedikit mungkin. Gabungkan, hapus duplikat, filter, dan pangkas kolom sebelum mengumpulkan data. Kirim data sesedikit mungkin ke node driver.

toPandas meningkat secara signifikan di Spark 2.3 . Ini mungkin bukan pendekatan terbaik jika Anda menggunakan versi Spark lebih awal dari 2.3.

Lihat di sini untuk detail lebih lanjut / hasil pembandingan.

Powers
sumber
2

Solusi yang mungkin adalah menggunakan collect_list()fungsi dari pyspark.sql.functions. Ini akan menggabungkan semua nilai kolom menjadi larik pyspark yang diubah menjadi daftar python saat dikumpulkan:

mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0] 
phgui
sumber