/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans.physical

import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
 * Describes how tuples that share common expressions must be spread across partitions when a
 * query runs in parallel on many machines.
 *
 * "Distribution" here means inter-node partitioning: it states how tuples are placed across the
 * physical machines of a cluster. Knowing this property lets some operators (e.g., Aggregate)
 * work partition-locally instead of globally.
 */
sealed trait Distribution {

  /**
   * Builds a default [[Partitioning]] that satisfies this distribution and has exactly
   * `numPartitions` partitions.
   */
  def createPartitioning(numPartitions: Int): Partitioning

  /**
   * The exact number of partitions this distribution insists on, or `None` when any number of
   * partitions is acceptable.
   */
  def requiredNumPartitions: Option[Int]
}

/**
 * A distribution that makes no promise at all about how data is co-located.
 *
 * It accepts any number of partitions and has no default partitioning; asking it for one is an
 * internal error.
 */
case object UnspecifiedDistribution extends Distribution {
  override def createPartitioning(numPartitions: Int): Partitioning =
    throw SparkException.internalError(
      "UnspecifiedDistribution does not have default partitioning.")

  override def requiredNumPartitions: Option[Int] = None
}

/**
 * A distribution in which every tuple of the dataset is co-located in one single partition.
 *
 * Its default partitioning is [[SinglePartition]], which is only valid for a partition count of 1.
 */
case object AllTuples extends Distribution {
  override def createPartitioning(numPartitions: Int): Partitioning = {
    val isSinglePartition = numPartitions == 1
    assert(isSinglePartition, "The default partitioning of AllTuples can only have 1 partition.")
    SinglePartition
  }

  override def requiredNumPartitions: Option[Int] = Some(1)
}

/**
 * Represents data where tuples that share the same values for the `clustering`
 * [[Expression Expressions]] will be co-located in the same partition.
 *
 * @param clustering            the expressions whose values decide co-location; must not be empty
 *                              (an [[AllTuples]] should be used for a single-partition
 *                              requirement).
 * @param requireAllClusterKeys When true, a `Partitioning` which satisfies this distribution
 *                              must match all `clustering` expressions in the same ordering.
 *                              Defaults to the value of
 *                              `SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION` read from
 *                              `SQLConf.get` when the instance is constructed.
 * @param requiredNumPartitions the exact number of partitions required, or `None` if any number
 *                              of partitions is acceptable.
 */
case class ClusteredDistribution(
    clustering: Seq[Expression],
    requireAllClusterKeys: Boolean = SQLConf.get.getConf(
      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION),
    requiredNumPartitions: Option[Int] = None) extends Distribution {
  require(
    clustering != Nil,
    "The clustering expressions of a ClusteredDistribution should not be Nil. " +
      "An AllTuples should be used to represent a distribution that only has " +
      "a single partition.")

  /**
   * Creates a [[HashPartitioning]] on `clustering` with `numPartitions` partitions.
   *
   * When `requiredNumPartitions` is defined, asserts that it equals `numPartitions`.
   */
  override def createPartitioning(numPartitions: Int): Partitioning = {
    // Only check when a requirement exists; binding the value avoids calling `.get` on the Option.
    requiredNumPartitions.foreach { required =>
      assert(required == numPartitions,
        s"This ClusteredDistribution requires $required partitions, but " +
          s"the actual number of partitions is $numPartitions.")
    }
    HashPartitioning(clustering, numPartitions)
  }

  /**
   * Checks if `expressions` match all `clustering` expressions in the same ordering.
   *
   * `Partitioning` should call this to check its expressions when `requireAllClusterKeys`
   * is set to true.
   *
   * @return true iff both sequences have the same length and each pair at the same position is
   *         semantically equal.
   */
  def areAllClusterKeysMatched(expressions: Seq[Expression]): Boolean = {
    expressions.length == clustering.length &&
      expressions.zip(clustering).forall {
        case (l, r) => l.semanticEquals(r)
      }
  }
}

/**
 * Represents the distribution required by a stateful operator in Structured Streaming.
 *
 * Each partition of a stateful operator initializes its own state store(s), independent of the
 * state store(s) of other partitions. Because the data in a state store cannot be repartitioned,
 * Spark must keep the physical partitioning of the stateful operator unchanged across Spark
 * versions; violating this requirement may cause silent correctness issues.
 *
 * Since this distribution relies on [[HashPartitioning]] for the physical partitioning of the
 * stateful operator, only [[HashPartitioning]] (including a HashPartitioning inside a
 * [[PartitioningCollection]]) is meant to satisfy it. When `_requiredNumPartitions` is 1,
 * [[SinglePartition]] is essentially the same as [[HashPartitioning]], so it satisfies this
 * distribution as well.
 *
 * NOTE(review): the check lives in `HashPartitioningLike.satisfies0`, so every implementation of
 * that trait (e.g. `CoalescedHashPartitioning`, `KeyGroupedPartitioning`) also passes it when its
 * expressions match — confirm that this is intended.
 *
 * NOTE: As of now this is applied only to stream-stream joins. Other stateful operators have been
 * using ClusteredDistribution, which could lay out the physical partitioning of the state in a
 * different way (ClusteredDistribution has a relaxed condition that multiple partitionings can
 * satisfy). The fix has to be designed so that it minimizes the possibility of breaking existing
 * checkpoints.
 *
 * TODO(SPARK-38204): address the issue explained in the note above.
 *
 * @param expressions            the hash keys; must not be empty.
 * @param _requiredNumPartitions the exact number of partitions required.
 */
case class StatefulOpClusteredDistribution(
    expressions: Seq[Expression],
    _requiredNumPartitions: Int) extends Distribution {
  require(
    expressions != Nil,
    "The expressions for hash of a StatefulOpClusteredDistribution should not be Nil. " +
      "An AllTuples should be used to represent a distribution that only has " +
      "a single partition.")

  override val requiredNumPartitions: Option[Int] = Some(_requiredNumPartitions)

  override def createPartitioning(numPartitions: Int): Partitioning = {
    val matchesRequirement = numPartitions == _requiredNumPartitions
    assert(matchesRequirement,
      s"This StatefulOpClusteredDistribution requires ${_requiredNumPartitions} " +
        s"partitions, but the actual number of partitions is $numPartitions.")
    HashPartitioning(expressions, numPartitions)
  }
}

/**
 * Represents data where tuples have been ordered according to the `ordering`
 * [[Expression Expressions]]. The requirement is:
 *   - For any 2 adjacent partitions, every row of the second partition must be larger than or
 *     equal to every row of the first partition, according to the `ordering` expressions.
 *
 * In other words, rows must be ordered across partitions, but not necessarily within a partition.
 * Its default partitioning is a [[RangePartitioning]] on `ordering`.
 */
case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
  require(
    ordering != Nil,
    "The ordering expressions of an OrderedDistribution should not be Nil. " +
      "An AllTuples should be used to represent a distribution that only has " +
      "a single partition.")

  override def requiredNumPartitions: Option[Int] = None

  override def createPartitioning(numPartitions: Int): Partitioning =
    RangePartitioning(ordering, numPartitions)

  /**
   * Returns true iff `expressions` has the same length as `ordering` and each expression is
   * semantically equal to the child of the [[SortOrder]] at the same position. Sort direction
   * and null ordering are not compared.
   */
  def areAllClusterKeysMatched(expressions: Seq[Expression]): Boolean = {
    val orderingKeys = ordering.map(_.child)
    expressions.length == orderingKeys.length &&
      expressions.zip(orderingKeys).forall { case (candidate, key) =>
        candidate.semanticEquals(key)
      }
  }
}

/**
 * A distribution in which the tuples are broadcast to every node. The entire set of tuples is
 * often transformed into a different data structure first; `mode` is carried over to the default
 * [[BroadcastPartitioning]], which must have exactly one partition.
 */
case class BroadcastDistribution(mode: BroadcastMode) extends Distribution {
  override def createPartitioning(numPartitions: Int): Partitioning = {
    val isSinglePartition = numPartitions == 1
    assert(isSinglePartition,
      "The default partitioning of BroadcastDistribution can only have 1 partition.")
    BroadcastPartitioning(mode)
  }

  override def requiredNumPartitions: Option[Int] = Some(1)
}

/**
 * Describes how an operator's output is split across partitions. Its two main properties are:
 *   1. the number of partitions, and
 *   2. whether it satisfies a given [[Distribution]].
 */
trait Partitioning {
  /** The number of partitions the data is split across. */
  val numPartitions: Int

  /**
   * Returns true iff the guarantees made by this [[Partitioning]] are strong enough for the
   * `required` [[Distribution]], i.e. the dataset does not need to be re-partitioned to meet it
   * (rows inside a partition may still need to be reorganized).
   *
   * A [[Partitioning]] never satisfies a [[Distribution]] whose
   * [[Distribution.requiredNumPartitions]] is defined and differs from `numPartitions`; only
   * after that count check passes is [[satisfies0]] consulted.
   */
  final def satisfies(required: Distribution): Boolean = {
    val partitionCountOk = required.requiredNumPartitions.forall(n => n == numPartitions)
    partitionCountOk && satisfies0(required)
  }

  /**
   * Creates a shuffle spec for this partitioning and its required distribution. The spec is
   * used when an operator has multiple children (e.g., join) to decide whether this child is
   * co-partitioned with the others, and therefore whether an extra shuffle must be introduced.
   *
   * The default implementation throws an internal error; partitionings that take part in this
   * decision override it.
   *
   * @param distribution the required clustered distribution for this partitioning
   */
  def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
    val partitioningName = getClass.getSimpleName
    throw SparkException.internalError(s"Unexpected partitioning: $partitioningName")
  }

  /**
   * Decides whether this [[Partitioning]] satisfies `required`, assuming the partition-count
   * check in [[satisfies]] has already passed.
   *
   * By default a [[Partitioning]] satisfies [[UnspecifiedDistribution]], and satisfies
   * [[AllTuples]] only when it has exactly one partition. Implementations can override this
   * method with special logic.
   */
  protected def satisfies0(required: Distribution): Boolean =
    required == UnspecifiedDistribution || (required == AllTuples && numPartitions == 1)
}

/**
 * A partitioning that records only its partition count. It overrides none of the satisfaction
 * rules of [[Partitioning]], so it satisfies only [[UnspecifiedDistribution]], plus [[AllTuples]]
 * when it has exactly one partition.
 */
case class UnknownPartitioning(override val numPartitions: Int) extends Partitioning

/**
 * Represents a partitioning where rows are spread evenly over the output partitions in a
 * round-robin fashion, starting from a random target partition number. It is used to implement
 * the DataFrame.repartition() operator. It overrides none of the satisfaction rules of
 * [[Partitioning]].
 */
case class RoundRobinPartitioning(override val numPartitions: Int) extends Partitioning

/**
 * A partitioning where all the data lives in exactly one partition.
 *
 * It satisfies every [[Distribution]] except [[BroadcastDistribution]], and its shuffle spec is
 * [[SinglePartitionShuffleSpec]].
 */
case object SinglePartition extends Partitioning {
  override val numPartitions: Int = 1

  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
    SinglePartitionShuffleSpec

  // One partition trivially co-locates every row, so only broadcast requirements are rejected.
  override def satisfies0(required: Distribution): Boolean = required match {
    case BroadcastDistribution(_) => false
    case _ => true
  }
}

/**
 * Common base for partitionings that are described by a list of `expressions`.
 *
 * As an [[Expression]], its children are `expressions` (so expression transformations descend
 * into them); it is [[Unevaluable]], non-nullable and typed as [[IntegerType]].
 */
trait HashPartitioningLike extends Expression with Partitioning with Unevaluable {
  def expressions: Seq[Expression]

  override def children: Seq[Expression] = expressions
  override def dataType: DataType = IntegerType
  override def nullable: Boolean = false

  /**
   * On top of the defaults from [[Partitioning]], satisfies:
   *   - a [[StatefulOpClusteredDistribution]] whose expressions are pairwise semantically equal
   *     to `expressions` (same length, same order);
   *   - a [[ClusteredDistribution]] whose keys all match exactly when `requireAllClusterKeys`
   *     is set, and otherwise one where every partitioning expression is semantically equal to
   *     some clustering expression.
   */
  override def satisfies0(required: Distribution): Boolean = {
    if (super.satisfies0(required)) {
      true
    } else {
      required match {
        case StatefulOpClusteredDistribution(requiredExprs, _) =>
          // Stateful operators need exactly the same keys in exactly the same order.
          expressions.length == requiredExprs.length &&
            expressions.zip(requiredExprs).forall { case (mine, theirs) =>
              mine.semanticEquals(theirs)
            }
        case c: ClusteredDistribution if c.requireAllClusterKeys =>
          // Partitioned on exactly the same clustering keys as the distribution.
          c.areAllClusterKeysMatched(expressions)
        case c: ClusteredDistribution =>
          // Every partitioning key must be one of the required clustering keys.
          expressions.forall(key => c.clustering.exists(_.semanticEquals(key)))
        case _ =>
          false
      }
    }
  }
}

/**
 * Represents a partitioning where rows are split across partitions based on the hash of
 * `expressions`. All rows for which `expressions` evaluate to the same values are guaranteed to
 * be in the same partition.
 *
 * Since [[StatefulOpClusteredDistribution]] relies on this partitioning, and Spark requires
 * stateful operators to keep the same physical partitioning for the lifetime of a query
 * (including restarts), the result of evaluating `partitionIdExpression` must not change across
 * Spark versions. Violating this requirement may cause silent correctness issues.
 */
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
  extends HashPartitioningLike {

  /**
   * Returns an expression producing a valid partition ID (non-negative and less than
   * `numPartitions`): `Pmod` of a [[CollationAwareMurmur3Hash]] over `expressions`.
   */
  def partitionIdExpression: Expression = {
    val hash = new CollationAwareMurmur3Hash(expressions)
    Pmod(hash, Literal(numPartitions))
  }

  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
    HashShuffleSpec(this, distribution)

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[Expression]): HashPartitioning =
    copy(expressions = newChildren)
}

/**
 * One output partition of a [[CoalescedHashPartitioning]], described by a range of reducer
 * indices of the original partitioning.
 *
 * NOTE(review): whether `endReducerIndex` is inclusive or exclusive is not visible here — confirm
 * against the code that builds these boundaries.
 */
case class CoalescedBoundary(
    startReducerIndex: Int,
    endReducerIndex: Int)

/**
 * Represents a partitioning where the partitions of a [[HashPartitioning]] have been coalesced
 * into fewer partitions.
 *
 * Each [[CoalescedBoundary]] in `partitions` is one output partition, while the partitioning
 * expressions are those of `from`.
 */
case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[CoalescedBoundary])
  extends HashPartitioningLike {

  override val numPartitions: Int = partitions.length

  override def expressions: Seq[Expression] = from.expressions

  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
    val underlying = from.createShuffleSpec(distribution)
    CoalescedHashShuffleSpec(underlying, partitions)
  }

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning = {
    // The children are `from.expressions`, so new children are written back into `from`.
    val newFrom = from.copy(expressions = newChildren)
    copy(from = newFrom)
  }
}

/**
 * Represents a partitioning where rows are split across partitions according to the transforms
 * defined by `expressions`. When defined, `partitionValues` holds, for each input partition, the
 * value of the partition key(s) after evaluation by those transforms, in ascending order. Its
 * length must equal the number of Spark partitions (a 1-1 mapping), and every row in
 * `partitionValues` must be unique.
 *
 * `originalPartitionValues`, by contrast, are the partition values of the original input splits
 * returned by the data source, and may contain duplicated values.
 *
 * Example: a data source reports the partition transform `[years(ts_col)]` with 4 input splits
 * whose partition values are `[0, 1, 2, 2]`. Then `expressions` is `[years(ts_col)]` and
 * `partitionValues` is `[0, 1, 2]`, i.e. 3 partitions with distinct partition values, obtained by
 * combining the two splits with partition value `2` into a single Spark partition. All rows of a
 * partition yield the same value once the `years` transform is applied to their `ts_col` (a
 * timestamp column). In the same example `originalPartitionValues` is `[0, 1, 2, 2]`, computed
 * from the original input splits.
 *
 * @param expressions partition expressions for the partitioning.
 * @param numPartitions the number of partitions.
 * @param partitionValues the values of the final cluster keys (after grouping the input splits
 *                        according to `expressions`); must be in ascending order and must NOT
 *                        contain duplicated values.
 * @param originalPartitionValues the input partition values before any grouping has been
 *                                applied; must be in ascending order and may contain duplicated
 *                                values.
 */
case class KeyGroupedPartitioning(
    expressions: Seq[Expression],
    numPartitions: Int,
    partitionValues: Seq[InternalRow] = Seq.empty,
    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {

  override def satisfies0(required: Distribution): Boolean =
    super.satisfies0(required) || satisfiesAsKeyGrouped(required)

  /** Key-grouped specific checks, consulted only when the inherited checks fail. */
  private def satisfiesAsKeyGrouped(required: Distribution): Boolean = required match {
    case c: ClusteredDistribution if c.requireAllClusterKeys =>
      // Partitioned on exactly the same clustering keys as the distribution, in the same order.
      c.areAllClusterKeysMatched(expressions)

    case c: ClusteredDistribution =>
      // Compare against the leaf expressions of the partition expressions rather than the
      // partition expressions themselves.
      val leaves = expressions.flatMap(_.collectLeaves())
      if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
        // Some join key (required clustering key) must overlap a partition key leaf, and every
        // partition expression must have exactly one leaf.
        val overlapsPartitionKeys =
          c.clustering.exists(key => leaves.exists(_.semanticEquals(key)))
        overlapsPartitionKeys && expressions.forall(_.collectLeaves().size == 1)
      } else {
        leaves.forall(leaf => c.clustering.exists(_.semanticEquals(leaf)))
      }

    case o: OrderedDistribution if SQLConf.get.v2BucketingAllowSorting =>
      o.areAllClusterKeysMatched(expressions)

    case _ =>
      false
  }

  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
    val fullSpec = KeyGroupedShuffleSpec(this, distribution)
    if (!SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
      fullSpec
    } else {
      // Join keys may be a subset of the partition keys: build a `KeyGroupedPartitioning`
      // grouped only on the positions whose `keyPositions` entry is non-empty (presumably the
      // partition keys that are also join keys) and return a spec over that partitioning.
      val joinKeyPositions = fullSpec.keyPositions.zipWithIndex.collect {
        case (positions, idx) if positions.nonEmpty => idx
      }
      val projected = KeyGroupedPartitioning(expressions, joinKeyPositions,
        partitionValues, originalPartitionValues)
      fullSpec.copy(partitioning = projected, joinKeyPositions = Some(joinKeyPositions))
    }
  }

  /** `partitionValues` with duplicates removed, comparing rows via [[InternalRowComparableWrapper]]. */
  lazy val uniquePartitionValues: Seq[InternalRow] = {
    val wrapped = partitionValues.map(row => InternalRowComparableWrapper(row, expressions))
    wrapped.distinct.map(_.row)
  }

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
    copy(expressions = newChildren)
}

/** Helpers for building and validating [[KeyGroupedPartitioning]] instances. */
object KeyGroupedPartitioning {
  /**
   * Builds a [[KeyGroupedPartitioning]] that keeps only the partition expressions at
   * `projectionPositions`, in that order.
   *
   * `partitionValues` and `originalPartitionValues` are both projected onto those positions.
   * The projected `partitionValues` are then de-duplicated (comparing rows through
   * [[InternalRowComparableWrapper]]), and the number of distinct values becomes the partition
   * count. The projected `originalPartitionValues` are kept as they are, duplicates included.
   */
  def apply(
      expressions: Seq[Expression],
      projectionPositions: Seq[Int],
      partitionValues: Seq[InternalRow],
      originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
    def projectRow(row: InternalRow): InternalRow =
      project(expressions, projectionPositions, row)

    val keptExpressions = projectionPositions.map(pos => expressions(pos))
    val projectedValues = partitionValues.map(projectRow)
    val keptOriginalValues = originalPartitionValues.map(projectRow)

    val distinctValues = projectedValues
      .map(row => InternalRowComparableWrapper(row, keptExpressions))
      .distinct
      .map(_.row)

    KeyGroupedPartitioning(keptExpressions, distinctValues.length, distinctValues,
      keptOriginalValues)
  }

  /**
   * Returns a new row holding the values of `input` at `positions`, in that order. Each value is
   * read using the data type of the expression at the same position in `expressions`.
   */
  def project(
      expressions: Seq[Expression],
      positions: Seq[Int],
      input: InternalRow): InternalRow = {
    val values: Array[Any] =
      positions.iterator.map(pos => input.get(pos, expressions(pos).dataType)).toArray
    new GenericInternalRow(values)
  }

  /**
   * Returns true iff every expression is either a reference (an [[Attribute]], or a chain of
   * [[GetStructField]] accesses rooted at one) or a [[TransformExpression]] with exactly one
   * child that is such a reference.
   */
  def supportsExpressions(expressions: Seq[Expression]): Boolean = {
    @tailrec
    def isReference(e: Expression): Boolean = e match {
      case _: Attribute => true
      case field: GetStructField => isReference(field.child)
      case _ => false
    }

    def isSupported(e: Expression): Boolean = e match {
      case t: TransformExpression
          if t.children.size == 1 && isReference(t.children.head) => true
      case other => isReference(other)
    }

    expressions.forall(isSupported)
  }
}

/**
 * Represents a partitioning where rows are split across partitions based on a total ordering of
 * the expressions in `ordering`. Data partitioned this way guarantees that, for any 2 adjacent
 * partitions, every row of the second partition is larger than every row of the first partition,
 * according to the `ordering` expressions.
 *
 * This guarantee is strictly stronger than what `OrderedDistribution(ordering)` requires, because
 * partitions do not overlap.
 *
 * This class extends [[Expression]] mainly so that expression transformations descend into its
 * children.
 */
case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
  extends Expression with Partitioning with Unevaluable {

  override def children: Seq[SortOrder] = ordering
  override def dataType: DataType = IntegerType
  override def nullable: Boolean = false

  override def satisfies0(required: Distribution): Boolean =
    super.satisfies0(required) || (required match {
      case OrderedDistribution(requiredOrdering) =>
        // The two orderings only need to agree on their common prefix.
        //
        // `ordering` is a prefix of `requiredOrdering`, e.g. [a, b] vs. [a, b, c]: every [a, b]
        // in a previous partition is smaller than every [a, b] in the following partition, so
        // every [a, b, c] in a previous partition is also smaller than every [a, b, c] in the
        // following one. Hence RangePartitioning(a, b) satisfies OrderedDistribution(a, b, c).
        //
        // `requiredOrdering` is a prefix of `ordering`, e.g. [a, b] vs. [a, b, c]: if some
        // [a1, b1] in a previous partition were larger than some [a2, b2] in the following
        // partition, then some [a1, b1, c1] would be larger than [a2, b2, c2], violating the
        // RangePartitioning definition. So every [a, b] in a previous partition is smaller than
        // or equal to every [a, b] in the following partition. Hence RangePartitioning(a, b, c)
        // satisfies OrderedDistribution(a, b).
        val prefixLength = math.min(requiredOrdering.size, ordering.size)
        requiredOrdering.take(prefixLength) == ordering.take(prefixLength)

      case c: ClusteredDistribution =>
        val sortKeys = ordering.map(_.child)
        if (c.requireAllClusterKeys) {
          // Partitioned on exactly the same clustering keys as the distribution.
          c.areAllClusterKeysMatched(sortKeys)
        } else {
          sortKeys.forall(key => c.clustering.exists(_.semanticEquals(key)))
        }

      case _ =>
        false
    })

  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
    RangeShuffleSpec(numPartitions, distribution)

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[Expression]): RangePartitioning =
    copy(ordering = newChildren.asInstanceOf[Seq[SortOrder]])
}

/**
 * A collection of [[Partitioning]]s that together describe the partitioning of a physical
 * operator's output. It is typically used for an operator with multiple children, where each
 * [[Partitioning]] in the collection describes the output in terms of expressions from one child.
 * For example, for a Join of tables `A` and `B` on `A.key1 = B.key2` using hash partitioning,
 * both `HashPartitioning(A.key1)` and `HashPartitioning(B.key2)` describe how the output is
 * partitioned. The partitionings in the collection do not need to be equivalent, which is
 * useful for outer joins.
 *
 * All partitionings must have the same `numPartitions`, and the collection must not be empty.
 */
case class PartitioningCollection(partitionings: Seq[Partitioning])
  extends Expression with Partitioning with Unevaluable {

  require(
    partitionings.map(_.numPartitions).distinct.length == 1,
    s"PartitioningCollection requires all of its partitionings have the same numPartitions.")

  // Only the members that are themselves expressions are exposed as children.
  override def children: Seq[Expression] = partitionings.collect { case e: Expression => e }

  override def dataType: DataType = IntegerType

  override def nullable: Boolean = false

  // Safe after the `require` above: the collection is non-empty and every count is equal.
  override val numPartitions: Int = partitionings.head.numPartitions

  /** Returns true if any member of this collection satisfies `required`. */
  override def satisfies0(required: Distribution): Boolean =
    partitionings.exists(p => p.satisfies(required))

  /** Collects the shuffle specs of the members that satisfy `distribution`. */
  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
    val satisfying = partitionings.filter(p => p.satisfies(distribution))
    val specs = satisfying.map(p => p.createShuffleSpec(distribution))
    ShuffleSpecCollection(specs)
  }

  override def toString: String =
    partitionings.iterator.map(_.toString).mkString("(", " or ", ")")

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[Expression]): PartitioningCollection =
    super.legacyWithNewChildren(newChildren).asInstanceOf[PartitioningCollection]
}

/**
 * Represents a partitioning where rows are collected, transformed and broadcast to each node in
 * the cluster.
 *
 * It reports a single partition, yet satisfies only [[UnspecifiedDistribution]] and a
 * [[BroadcastDistribution]] with an equal `mode` (notably not [[AllTuples]]).
 */
case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
  override val numPartitions: Int = 1

  override def satisfies0(required: Distribution): Boolean = required match {
    case UnspecifiedDistribution => true
    case b: BroadcastDistribution => b.mode == mode
    case _ => false
  }
}

/**
 * A [[ShuffleSpec]] is used when an operator has multiple children (e.g., join) and one or more
 * of them have their own requirement regarding whether their data can be considered
 * co-partitioned with the others. It offers APIs for:
 *
 *   - Comparing with the specs of the operator's other children and checking whether they are
 *      compatible. When two specs are compatible, their data can be considered co-partitioned,
 *      and Spark may be able to eliminate a shuffle.
 *   - Creating a partitioning that can be used to re-partition another child, so that the other
 *      child gets a partitioning compatible with this node's.
 */

/**
 * Represents a partitioning where partition IDs are passed through directly from the
 * [[DirectShufflePartitionID]] expression. Users choose this scheme when they want to control
 * partition placement directly instead of relying on hash-based partitioning.
 *
 * This partitioning maps directly to the PartitionIdPassthrough RDD partitioner.
 */
case class ShufflePartitionIdPassThrough(
    expr: DirectShufflePartitionID,
    numPartitions: Int) extends Expression with Partitioning with Unevaluable {

  /** The partition ID: `Pmod` of the child of `expr` by `numPartitions`. */
  def partitionIdExpression: Expression = Pmod(expr.child, Literal(numPartitions))

  /** The single partitioning expression, `expr`. */
  def expressions: Seq[Expression] = List(expr)

  override def children: Seq[Expression] = List(expr)
  override def dataType: DataType = IntegerType
  override def nullable: Boolean = false

  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
    ShufflePartitionIdPassThroughSpec(this, distribution)

  /**
   * On top of the defaults from [[Partitioning]], satisfies a [[ClusteredDistribution]] whose
   * clustering matches the child of `expr`: exactly (as the sole key) when
   * `requireAllClusterKeys` is set, otherwise as any one of the clustering keys.
   */
  override def satisfies0(required: Distribution): Boolean = {
    if (super.satisfies0(required)) {
      true
    } else {
      required match {
        // TODO(SPARK-53428): Support Direct Passthrough Partitioning in the Streaming Joins
        case c: ClusteredDistribution =>
          val partitionKey = expr.child
          if (c.requireAllClusterKeys) {
            c.areAllClusterKeysMatched(List(partitionKey))
          } else {
            c.clustering.exists(_.semanticEquals(partitionKey))
          }
        case _ =>
          false
      }
    }
  }

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[Expression]): ShufflePartitionIdPassThrough = {
    val newExpr = newChildren.head.asInstanceOf[DirectShufflePartitionID]
    copy(expr = newExpr)
  }
}

trait ShuffleSpec {
  /** How many partitions this spec describes. */
  def numPartitions: Int

  /**
   * Tells whether this spec and `other` describe co-partitioned data, meaning that two sides
   * with these specs can be joined without inserting a shuffle.
   *
   * Spark relies on this relation being reflexive, symmetric and transitive.
   */
  def isCompatibleWith(other: ShuffleSpec): Boolean

  /** Whether [[createPartitioning]] may be used to derive partitionings for the other children. */
  def canCreatePartitioning: Boolean

  /**
   * Builds a partitioning over `clustering` that can be used to re-partition another side so
   * that it lines up with this spec.
   *
   * Only invoked when [[isCompatibleWith]] returned false on the side that `clustering` comes
   * from. By default this is unsupported and throws [[SparkUnsupportedOperationException]].
   */
  def createPartitioning(clustering: Seq[Expression]): Partitioning = {
    throw SparkUnsupportedOperationException()
  }
}

/** Spec for output held in exactly one partition. */
case object SinglePartitionShuffleSpec extends ShuffleSpec {
  override def numPartitions: Int = 1

  // Compatible with any spec that also has exactly one partition.
  override def isCompatibleWith(other: ShuffleSpec): Boolean = 1 == other.numPartitions

  override def canCreatePartitioning: Boolean = false

  override def createPartitioning(clustering: Seq[Expression]): Partitioning = SinglePartition
}

/**
 * Spec for output produced by `RangePartitioning`.
 *
 * Range boundaries are randomly sampled, so range-partitioned children cannot be assumed to be
 * co-partitioned with anything else. This spec is therefore compatible only with
 * [[SinglePartitionShuffleSpec]] (when it has a single partition itself) or with a
 * [[ShuffleSpecCollection]] containing a compatible member. It never creates partitionings for
 * other children.
 */
case class RangeShuffleSpec(
    numPartitions: Int,
    distribution: ClusteredDistribution) extends ShuffleSpec {

  override def canCreatePartitioning: Boolean = false

  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
    case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith)
    case SinglePartitionShuffleSpec => numPartitions == 1
    case _ => false
  }
}

/** Spec for output hash-partitioned by [[HashPartitioning]] under a [[ClusteredDistribution]]. */
case class HashShuffleSpec(
    partitioning: HashPartitioning,
    distribution: ClusteredDistribution) extends ShuffleSpec {

  /**
   * For each hash partition key, the set of indices at which that key occurs among the cluster
   * keys (compared in canonical form). With cluster keys [a, b, b] and hash keys [a, b] the
   * result is [(0), (1, 2)]. A hash key that is not a cluster key maps to an empty set.
   *
   * Two `HashShuffleSpec`s are compatible when their position sets overlap at every index. For
   * example, cluster keys [a, b, b] / [x, y, z] with hash keys [a, b] / [x, z] give
   * [(0), (1, 2)] / [(0), (2)], which overlap pairwise.
   */
  lazy val hashKeyPositions: Seq[mutable.BitSet] = {
    val positionsByKey = mutable.Map.empty[Expression, mutable.BitSet]
    for ((clusterKey, idx) <- distribution.clustering.zipWithIndex) {
      positionsByKey.getOrElseUpdate(clusterKey.canonicalized, mutable.BitSet.empty) += idx
    }
    partitioning.expressions.map { hashKey =>
      positionsByKey.getOrElse(hashKey.canonicalized, mutable.BitSet.empty)
    }
  }

  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
    case SinglePartitionShuffleSpec =>
      numPartitions == 1
    case that: HashShuffleSpec =>
      // Co-partitioned iff the distributions have equally many clustering expressions, the
      // partitionings have equally many partitions and hash keys, and at every index the hash
      // keys' cluster-key positions overlap.
      val sameShape =
        distribution.clustering.length == that.distribution.clustering.length &&
          numPartitions == that.partitioning.numPartitions &&
          partitioning.expressions.length == that.partitioning.expressions.length
      sameShape && hashKeyPositions.zip(that.hashKeyPositions).forall {
        case (mine, theirs) => (mine & theirs).nonEmpty
      }
    case ShuffleSpecCollection(specs) =>
      specs.exists(isCompatibleWith)
    case _ =>
      false
  }

  override def canCreatePartitioning: Boolean = {
    // Letting a hash spec built on only part of the join keys (the cluster keys) dictate the
    // other side's partitioning risks data skew. With the conf below enabled this is refused,
    // and the planner instead shuffles with the default partitioning of
    // `ClusteredDistribution`, which uses all the join keys.
    val requireAllKeys = SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)
    !requireAllKeys || distribution.areAllClusterKeysMatched(partitioning.expressions)
  }

  override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
    // For every hash key, take the other side's clustering expression at the lowest position
    // that this hash key occupies among our cluster keys.
    val newHashKeys = hashKeyPositions.map(positions => clustering(positions.head))
    HashPartitioning(newHashKeys, numPartitions)
  }

  override def numPartitions: Int = partitioning.numPartitions
}

/**
 * Spec derived from `from` whose partitions have been coalesced according to `partitions`
 * (one boundary entry per resulting partition).
 */
case class CoalescedHashShuffleSpec(
    from: ShuffleSpec,
    partitions: Seq[CoalescedBoundary]) extends ShuffleSpec {

  override def numPartitions: Int = partitions.length

  override def canCreatePartitioning: Boolean = false

  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
    case CoalescedHashShuffleSpec(otherFrom, otherPartitions) =>
      // Both sides must use identical boundaries and come from compatible parent specs.
      partitions == otherPartitions && from.isCompatibleWith(otherFrom)
    case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith)
    case SinglePartitionShuffleSpec => numPartitions == 1
    case _ => false
  }
}

/**
 * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]].
 *
 * @param partitioning key grouped partitioning
 * @param distribution distribution
 * @param joinKeyPositions position of join keys among cluster keys.
 *                         This is set if joining on a subset of cluster keys is allowed.
 */
case class KeyGroupedShuffleSpec(
    partitioning: KeyGroupedPartitioning,
    distribution: ClusteredDistribution,
    joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec {

  /**
   * A sequence where each element is the set of positions of a partition expression's key among
   * the cluster keys. For instance, if cluster keys are [a, b, b] and partition expressions are
   * [bucket(4, a), years(b)], the result will be [(0), (1, 2)]. A key that is not among the
   * cluster keys maps to an empty set.
   *
   * Note that we only allow each partition expression to contain a single partition key (this is
   * asserted below). Therefore the mapping here is very similar to that from `HashShuffleSpec`.
   */
  lazy val keyPositions: Seq[mutable.BitSet] = {
    val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
    distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) =>
      distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
    }
    partitioning.expressions.map { e =>
      val leaves = e.collectLeaves()
      assert(leaves.size == 1, s"Expected exactly one child from $e, but found ${leaves.size}")
      distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty)
    }
  }

  override def numPartitions: Int = partitioning.numPartitions

  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
    // Here we check:
    //  1. both distributions have the same number of clustering keys
    //  2. both partitionings have the same number of partitions
    //  3. partition expressions from both sides are compatible, which means:
    //    3.1 both sides have the same number of partition expressions
    //    3.2 for each pair of partition expressions at the same index, the corresponding
    //        partition keys must share overlapping positions in their respective clustering keys.
    //    3.3 each pair of partition expressions at the same index must share compatible
    //        transform functions.
    //  4. the partition values from both sides are pairwise equal, i.e. both sides list the same
    //     partition values in the same order.
    case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
      distribution.clustering.length == otherDistribution.clustering.length &&
        numPartitions == otherSpec.numPartitions &&
        areKeysCompatible(otherSpec) &&
        partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
          case (left, right) =>
            // Both rows are compared using this side's partition expressions.
            InternalRowComparableWrapper(left, partitioning.expressions)
              .equals(InternalRowComparableWrapper(right, partitioning.expressions))
        }
    case ShuffleSpecCollection(specs) =>
      specs.exists(isCompatibleWith)
    case _ => false
  }

  /**
   * Whether the partition keys (i.e., partition expressions) are compatible between this and the
   * `other` spec: same number of expressions, overlapping cluster-key positions at every index,
   * and pairwise-compatible expressions.
   */
  def areKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = {
    val expressions = partitioning.expressions
    val otherExpressions = other.partitioning.expressions

    expressions.length == otherExpressions.length && {
      val otherKeyPositions = other.keyPositions
      keyPositions.zip(otherKeyPositions).forall { case (left, right) =>
        left.intersect(right).nonEmpty
      }
    } && expressions.zip(otherExpressions).forall {
      case (l, r) => isExpressionCompatible(l, r)
    }
  }

  /**
   * Two leaf expressions are always compatible. Two transforms must be the same function, unless
   * the push-part-values and allow-compatible-transforms v2 bucketing confs are both enabled and
   * partially clustered distribution is disabled; in that case `isCompatible` decides instead.
   * Any other combination is incompatible.
   */
  private def isExpressionCompatible(left: Expression, right: Expression): Boolean =
    (left, right) match {
      case (_: LeafExpression, _: LeafExpression) => true
      case (l: TransformExpression, r: TransformExpression) =>
        val conf = SQLConf.get
        if (conf.v2BucketingPushPartValuesEnabled &&
            !conf.v2BucketingPartiallyClusteredDistributionEnabled &&
            conf.v2BucketingAllowCompatibleTransforms) {
          l.isCompatible(r)
        } else {
          l.isSameFunction(r)
        }
      case _ => false
    }

  /**
   * Returns, for each partition expression of this spec, the [[Reducer]] (if any) that reduces it
   * onto the partition expression at the same index in `other`.
   * <p>
   * The two specs' partition expressions are paired by index (zipped). Only pairs where both
   * sides are [[TransformExpression]]s can yield a reducer. Every other pair, and every
   * transform pair for which `TransformExpression.reducers` finds nothing, gets `None`.
   * <p>
   * When a value is returned, it holds one entry per zipped pair. If no pair is reducible,
   * `None` is returned instead of a sequence of `None`s.
   *
   * @param other other key-grouped shuffle spec
   * @return per-expression reducers, or `None` if none of the partition expressions is reducible
   */
  def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = {
    val results = partitioning.expressions.zip(other.partitioning.expressions).map {
      case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
      case _ => None
    }

    // optimize to not return a value, if none of the partition expressions are reducible
    if (results.forall(_.isEmpty)) None else Some(results)
  }

  override def canCreatePartitioning: Boolean = {
    val conf = SQLConf.get
    conf.v2BucketingShuffleEnabled &&
      !conf.v2BucketingPartiallyClusteredDistributionEnabled &&
      partitioning.expressions.forall {
        case _: AttributeReference | _: TransformExpression => true
        case _ => false
      }
  }

  override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
    assert(clustering.size == distribution.clustering.size,
      "Required distributions of join legs should be the same size.")

    // Rebuild each partition expression over the other side's clustering expression found at
    // the lowest position this expression's key occupies among our cluster keys.
    val newExpressions = partitioning.expressions.zip(keyPositions).map {
      case (te: TransformExpression, positionSet) =>
        te.copy(children = te.children.map(_ => clustering(positionSet.head)))
      case (_, positionSet) => clustering(positionSet.head)
    }
    KeyGroupedPartitioning(newExpressions,
      partitioning.numPartitions,
      partitioning.partitionValues)
  }
}

/** Companion helpers for [[KeyGroupedShuffleSpec]]. */
object KeyGroupedShuffleSpec {
  /**
   * Reduces a partition value field by field and wraps the result for comparison.
   *
   * The fields of `row` are read using the data types of `expressions` and are paired
   * positionally with `reducers`. A field whose reducer slot is `Some(reducer)` is replaced by
   * `reducer.reduce(field)`; a field whose slot is `None` is kept as is.
   *
   * NOTE(review): the pairing uses `zip`, so if `reducers` is shorter than `expressions` the
   * trailing fields are silently dropped. Callers presumably pass one slot per expression;
   * confirm.
   *
   * @param row         the partition value to reduce
   * @param expressions partition expressions describing the fields of `row`
   * @param reducers    one optional reducer per field
   * @return the reduced row wrapped in an [[InternalRowComparableWrapper]] over `expressions`
   */
  def reducePartitionValue(
      row: InternalRow,
      expressions: Seq[Expression],
      reducers: Seq[Option[Reducer[_, _]]]): InternalRowComparableWrapper = {
    val partitionVals = row.toSeq(expressions.map(_.dataType))
    val reducedRow = partitionVals.zip(reducers).map {
      // Test against the erased `Reducer[_, _]` (a checkable type pattern) and then widen
      // explicitly. Matching on `Reducer[Any, Any]` directly is an unchecked pattern because its
      // type arguments are erased at runtime.
      case (v, Some(reducer: Reducer[_, _])) =>
        reducer.asInstanceOf[Reducer[Any, Any]].reduce(v)
      case (v, _) => v
    }.toArray
    InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions)
  }
}

/** Spec for output partitioned by [[ShufflePartitionIdPassThrough]] under a [[ClusteredDistribution]]. */
case class ShufflePartitionIdPassThroughSpec(
    partitioning: ShufflePartitionIdPassThrough,
    distribution: ClusteredDistribution) extends ShuffleSpec {

  /**
   * Indices at which the pass-through key (`partitioning.expr.child`) appears among the cluster
   * keys, compared in canonical form. The set is empty if the key is not a cluster key. This
   * plays the same role as `HashShuffleSpec.hashKeyPositions`, but for a single key.
   */
  lazy val keyPositions: mutable.BitSet = {
    val target = partitioning.expr.child.canonicalized
    val hits = mutable.BitSet.empty
    for ((clusterKey, idx) <- distribution.clustering.zipWithIndex
         if clusterKey.canonicalized == target) {
      hits += idx
    }
    hits
  }

  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
    case SinglePartitionShuffleSpec =>
      numPartitions == 1
    case that: ShufflePartitionIdPassThroughSpec =>
      // The pass-through partitioning has exactly one key. Co-partitioning therefore requires
      // equally many clustering expressions, equally many partitions, and that the two keys
      // share at least one position among their respective cluster keys.
      distribution.clustering.length == that.distribution.clustering.length &&
        numPartitions == that.numPartitions &&
        (keyPositions & that.keyPositions).nonEmpty
    case ShuffleSpecCollection(specs) =>
      specs.exists(isCompatibleWith)
    case _ =>
      false
  }

  // Creating partitionings from a ShufflePartitionIdPassThrough is not supported.
  override def canCreatePartitioning: Boolean = false

  override def numPartitions: Int = partitioning.numPartitions
}

/**
 * A group of specs treated as alternatives. The group is compatible with `other` as soon as any
 * member is, and it can create partitionings only if every member can.
 */
case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
  override def isCompatibleWith(other: ShuffleSpec): Boolean =
    specs.exists(spec => spec.isCompatibleWith(other))

  override def canCreatePartitioning: Boolean =
    specs.forall(spec => spec.canCreatePartitioning)

  override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
    // Cost is currently measured only by the number of partitions, and every member must share
    // that number, so delegating to the first member is as good as any other choice.
    val distinctCounts = specs.map(_.numPartitions).distinct
    require(distinctCounts.size == 1, "expected all specs in the collection " +
      "to have the same number of partitions")
    specs.head.createPartitioning(clustering)
  }

  override def numPartitions: Int = {
    require(specs.nonEmpty, "expected specs to be non-empty")
    specs.head.numPartitions
  }
}
