Problem description
I need help with a pyspark.sql function that creates a new variable aggregating the records over a specified Window() into a map of key-value pairs.
Reproducible data
df = spark.createDataFrame(
    [
        ('AK', "2022-05-02", 1651449600, 'US', 3),
        ('AK', "2022-05-03", 1651536000, 'ON', 1),
        ('AK', "2022-05-04", 1651622400, 'CO', 1),
        ('AK', "2022-05-06", 1651795200, 'AK', 1),
        ('AK', "2022-05-06", 1651795200, 'US', 5)
    ],
    ["state", "ds", "ds_num", "region", "count"]
)
df.show()
# +-----+----------+----------+------+-----+
# |state|        ds|    ds_num|region|count|
# +-----+----------+----------+------+-----+
# |   AK|2022-05-02|1651449600|    US|    3|
# |   AK|2022-05-03|1651536000|    ON|    1|
# |   AK|2022-05-04|1651622400|    CO|    1|
# |   AK|2022-05-06|1651795200|    AK|    1|
# |   AK|2022-05-06|1651795200|    US|    5|
# +-----+----------+----------+------+-----+
Partial solutions
Set of regions over the window frame:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
days = lambda i: i * 86400
df.withColumn('regions_4W',
              F.collect_set('region').over(
                  Window().partitionBy('state').orderBy('ds_num').rangeBetween(-days(27), 0)))\
    .sort('ds')\
    .show()
# +-----+----------+----------+------+-----+----------------+
# |state|        ds|    ds_num|region|count|      regions_4W|
# +-----+----------+----------+------+-----+----------------+
# |   AK|2022-05-02|1651449600|    US|    3|            [US]|
# |   AK|2022-05-03|1651536000|    ON|    1|        [US, ON]|
# |   AK|2022-05-04|1651622400|    CO|    1|    [CO, US, ON]|
# |   AK|2022-05-06|1651795200|    AK|    1|[CO, US, ON, AK]|
# |   AK|2022-05-06|1651795200|    US|    5|[CO, US, ON, AK]|
# +-----+----------+----------+------+-----+----------------+
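A note on the frame (not part of the original post): since ds_num is in epoch seconds, rangeBetween(-days(27), 0) keeps every row whose ds_num is at most 27 days before the current row's, i.e. a trailing 4-week window that includes the current day. A quick arithmetic check of the lower bound for the 2022-05-06 rows:

days = lambda i: i * 86400            # one day expressed in seconds, as in the question

# rows kept for ds_num = 1651795200 (2022-05-06) satisfy ds_num >= lower_bound
lower_bound = 1651795200 - days(27)   # 1649462400, i.e. 2022-04-09 00:00:00 UTC
print(lower_bound)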
Map of counts per state and ds:
df\
    .groupby('state', 'ds', 'ds_num')\
    .agg(F.map_from_entries(F.collect_list(F.struct("region", "count"))).alias("count_rolling_4W"))\
    .sort('ds')\
    .show()
# +-----+----------+----------+------------------+
# |state|        ds|    ds_num|  count_rolling_4W|
# +-----+----------+----------+------------------+
# |   AK|2022-05-02|1651449600|         {US -> 3}|
# |   AK|2022-05-03|1651536000|         {ON -> 1}|
# |   AK|2022-05-04|1651622400|         {CO -> 1}|
# |   AK|2022-05-06|1651795200|{AK -> 1, US -> 5}|
# +-----+----------+----------+------------------+
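For context, a minimal illustration that is not part of the original post (the literal values are arbitrary): map_from_entries turns an array of two-field structs into a map, using the first field as the key and the second as the value, which is exactly what collect_list(struct(...)) feeds it above.

import pyspark.sql.functions as F

spark.range(1).select(
    F.map_from_entries(F.array(
        F.struct(F.lit('AK').alias('region'), F.lit(1).alias('count')),
        F.struct(F.lit('US').alias('region'), F.lit(5).alias('count')),
    )).alias('m')
).show(truncate=False)
# -> {AK -> 1, US -> 5}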
Desired output
What I need is a map that aggregates the data for every key present in the specified window:
+-----+----------+----------+------------------------------------+
|state|        ds|    ds_num|                    count_rolling_4W|
+-----+----------+----------+------------------------------------+
|   AK|2022-05-02|1651449600|                           {US -> 3}|
|   AK|2022-05-03|1651536000|                  {US -> 3, ON -> 1}|
|   AK|2022-05-04|1651622400|         {US -> 3, ON -> 1, CO -> 1}|
|   AK|2022-05-06|1651795200|{US -> 8, ON -> 1, CO -> 1, AK -> 1}|
+-----+----------+----------+------------------------------------+
Reply from a uj5u.com user:
You can use the higher-order functions transform and aggregate as shown below:
from pyspark.sql import Window, functions as F
w = Window.partitionBy('state').orderBy('ds_num').rangeBetween(-days(27), 0)
df1 = (df.withColumn('regions', F.collect_set('region').over(w))
         .withColumn('counts', F.collect_list(F.struct('region', 'count')).over(w))
         .withColumn('counts',
                     F.transform(
                         'regions',
                         lambda x: F.aggregate(
                             F.filter('counts', lambda y: y['region'] == x),
                             F.lit(0),
                             lambda acc, v: acc + v['count']
                         )
                     ))
         .withColumn('count_rolling_4W', F.map_from_arrays('regions', 'counts'))
         .drop('counts', 'regions')
      )
df1.show(truncate=False)
# +-----+----------+----------+------+-----+------------------------------------+
# |state|ds        |ds_num    |region|count|count_rolling_4W                    |
# +-----+----------+----------+------+-----+------------------------------------+
# |AK   |2022-05-02|1651449600|US    |3    |{US -> 3}                           |
# |AK   |2022-05-03|1651536000|ON    |1    |{US -> 3, ON -> 1}                  |
# |AK   |2022-05-04|1651622400|CO    |1    |{CO -> 1, US -> 3, ON -> 1}         |
# |AK   |2022-05-06|1651795200|AK    |1    |{CO -> 1, US -> 3, ON -> 1, AK -> 1}|
# |AK   |2022-05-06|1651795200|US    |5    |{CO -> 1, US -> 8, ON -> 1, AK -> 1}|
# +-----+----------+----------+------+-----+------------------------------------+
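To see what the inner filter/aggregate pair computes in isolation, here is a minimal standalone sketch that is not from the original answer (the literal array mimics the collected counts column): F.filter keeps only the entries whose region matches, and F.aggregate folds their count values into a sum starting from 0.

import pyspark.sql.functions as F

demo = spark.range(1).select(
    F.array(
        F.struct(F.lit('US').alias('region'), F.lit(3).alias('count')),
        F.struct(F.lit('ON').alias('region'), F.lit(1).alias('count')),
        F.struct(F.lit('US').alias('region'), F.lit(5).alias('count')),
    ).alias('counts')
)

demo.select(
    F.aggregate(
        F.filter('counts', lambda y: y['region'] == F.lit('US')),  # keep only the US entries
        F.lit(0),                                                  # initial accumulator
        lambda acc, v: acc + v['count']                            # running sum of counts
    ).alias('us_total')
).show()
# -> 8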
Reply from a uj5u.com user:
Good question. This method uses two windows and two higher-order functions (aggregate and map_from_entries):
from pyspark.sql import functions as F, Window
w1 = Window.partitionBy('state', 'region').orderBy('ds')
w2 = Window.partitionBy('state').orderBy('ds')
df = df.withColumn('acc_count', F.sum('count').over(w1))
df = df.withColumn('maps', F.collect_set(F.struct('region', 'acc_count')).over(w2))
df = df.groupBy('state', 'ds', 'ds_num').agg(
    F.aggregate(
        F.first('maps'),
        F.create_map(F.first('region'), F.first('acc_count')),
        lambda m, x: F.map_zip_with(m, F.map_from_entries(F.array(x)),
                                    lambda k, v1, v2: F.greatest(v2, v1))
    ).alias('count_rolling_4W')
)
df.show(truncate=0)
# +-----+----------+----------+------------------------------------+
# |state|ds        |ds_num    |count_rolling_4W                    |
# +-----+----------+----------+------------------------------------+
# |AK   |2022-05-02|1651449600|{US -> 3}                           |
# |AK   |2022-05-03|1651536000|{ON -> 1, US -> 3}                  |
# |AK   |2022-05-04|1651622400|{CO -> 1, US -> 3, ON -> 1}         |
# |AK   |2022-05-06|1651795200|{AK -> 1, US -> 8, ON -> 1, CO -> 1}|
# +-----+----------+----------+------------------------------------+
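The merge step works because map_zip_with pairs the two maps by key, yielding NULL for a key missing from one side, and greatest skips NULLs, so each key keeps its largest accumulated count. A minimal standalone sketch of that merge, not from the original answer (the two literal maps are arbitrary):

import pyspark.sql.functions as F

demo = spark.range(1).select(
    F.create_map(F.lit('US'), F.lit(3), F.lit('ON'), F.lit(1)).alias('m1'),
    F.create_map(F.lit('US'), F.lit(8), F.lit('AK'), F.lit(1)).alias('m2'),
)

demo.select(
    F.map_zip_with('m1', 'm2',
                   lambda k, v1, v2: F.greatest(v1, v2)).alias('merged')
).show(truncate=False)
# -> {US -> 8, ON -> 1, AK -> 1}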
Reply from a uj5u.com user:
Assuming the state, ds, ds_num and region columns in the source dataframe are unique (they can be treated as a primary key), this snippet does the job:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
days = lambda i: i * 86400
df.alias('a').join(df.alias('b'), 'state') \
    .where((F.col('a.ds_num') - F.col('b.ds_num')).between(0, days(27))) \
    .select('state', 'a.ds', 'a.ds_num', 'b.region', 'b.count') \
    .dropDuplicates() \
    .groupBy('state', 'ds', 'ds_num', 'region').sum('count') \
    .groupBy('state', 'ds', 'ds_num') \
    .agg(F.map_from_entries(F.collect_list(F.struct("region", "sum(count)"))).alias("count_rolling_4W")) \
    .orderBy('a.ds') \
    .show(truncate=False)
Result:
+-----+----------+----------+------------------------------------+
|state|ds        |ds_num    |count_rolling_4W                    |
+-----+----------+----------+------------------------------------+
|AK   |2022-05-02|1651449600|{US -> 3}                           |
|AK   |2022-05-03|1651536000|{US -> 3, ON -> 1}                  |
|AK   |2022-05-04|1651622400|{US -> 3, ON -> 1, CO -> 1}         |
|AK   |2022-05-06|1651795200|{US -> 8, ON -> 1, CO -> 1, AK -> 1}|
+-----+----------+----------+------------------------------------+
It may look complicated, but it simply rewrites the window as a cross join to get better control over the result.
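To make the join semantics concrete, here is a sketch that is not from the original answer (pairs and contrib_ds are illustrative names) showing just the join and range filter, before deduplication and aggregation: each a-row is matched with every b-row of the same state whose ds_num lies within the preceding 27 days, inclusive.

import pyspark.sql.functions as F

days = lambda i: i * 86400

pairs = (
    df.alias('a').join(df.alias('b'), 'state')
      .where((F.col('a.ds_num') - F.col('b.ds_num')).between(0, days(27)))
      .select('state',
              F.col('a.ds').alias('ds'),
              F.col('b.ds').alias('contrib_ds'),
              F.col('b.region').alias('region'),
              F.col('b.count').alias('count'))
)
pairs.orderBy('ds', 'contrib_ds').show()
# The two 2022-05-06 source rows each match all five b-rows, producing duplicate
# (state, ds, region, count) pairs; this is why the answer calls dropDuplicates().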