Spark DataFrame aggregate with limited random samples

I want to aggregate by one column and keep a fixed number of randomly sampled values for each key. One way to do this is to first use collect_list(), but with that approach the job can run out of memory when some keys have a very large number of values. Hence I wrote the aggregator below and have verified that it works. Please provide feedback on any improvements that can be made.
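
For context, a minimal sketch of the collect_list() route mentioned above; the DataFrame df and the column names key and value are illustrative, not from my actual job:

import org.apache.spark.sql.functions.collect_list

// Materialises every value for each key before any sampling can happen,
// so a single heavily skewed key can exhaust executor memory.
val allValues = df
  .groupBy("key")
  .agg(collect_list("value").as("values"))

The aggregator below instead keeps at most limit values per key while the data streams through.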



import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
import scala.collection.mutable.{WrappedArray, ListBuffer}

class RandomSampleAggregator(limit: Int) extends UserDefinedAggregateFunction {
  var streamCount = 0 // count of values coming in for update

  // Data types of input arguments of this aggregate function
  def inputSchema: StructType = new StructType().add("value", StringType)

  def bufferSchema: StructType = new StructType().add("values", ArrayType(StringType, true))

  // The data type of the returned value
  override def dataType: DataType = ArrayType(StringType, true)

  def deterministic: Boolean = false

  // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
  // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
  // the opportunity to update its values. Note that arrays and maps inside the buffer are still
  // immutable.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ListBuffer[String]()
  }

  // Updates the given aggregation buffer `buffer` with new input data from `input`
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    streamCount += 1
    if (!input.isNullAt(0)) {
      val seq = buffer(0).asInstanceOf[WrappedArray[String]]
      if (seq.length < limit) {
        buffer(0) = input.getAs[String](0) +: seq
      } else {
        // Reservoir sampling
        val randIndex = scala.util.Random.nextInt(streamCount)
        if (randIndex < limit) {
          val seq = buffer(0).asInstanceOf[WrappedArray[String]]
          seq.update(randIndex, input.getAs[String](0))
          buffer(0) = seq
        }
      }
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer1(0) != null && buffer2 != null && buffer2(0) != null) {
      val seq1 = buffer1(0).asInstanceOf[WrappedArray[String]]
      val seq2 = buffer2(0).asInstanceOf[WrappedArray[String]]
      if (seq1.length + seq2.length <= limit) {
        buffer1(0) = seq1 ++ seq2
      } else {
        buffer1(0) = scala.util.Random.shuffle(seq1 ++ seq2).take(limit)
      }
    }
  }

  def evaluate(buffer: Row): Any = {
    if (buffer(0) == null) {
      ListBuffer[String]()
    } else {
      buffer(0).asInstanceOf[WrappedArray[String]]
    }
  }
}
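
For completeness, this is roughly how I call it; df and the column names key and value are again placeholder names:

import org.apache.spark.sql.functions.col

// Keep at most 5 randomly sampled values per key.
val sampleAgg = new RandomSampleAggregator(5)

val sampled = df
  .groupBy("key")
  .agg(sampleAgg(col("value")).as("sampled_values"))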