diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala index 7a6b02237d..a26e3d1796 100644 --- a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala @@ -186,19 +186,32 @@ private[codegen] object CometBatchKernelCodegenOutput { * * Scalars emit `perRow` only. Complex types emit both. Inner setup bubbles up so deep child * casts land at the batch prelude. + * + * `nested` distinguishes the root output vector from a child of a List / Map / Struct. + * `allocateOutput` pre-sizes the root to exactly `numRows` and the kernel writes one scalar per + * row, so the root's fixed-width `set` is always in bounds. A child's element count is instead + * the data-dependent sum of per-row collection sizes, which `numRows` does not bound. We cannot + * pre-size the child either: each row's `ArrayData` / `MapData` is produced by Spark's + * generated `ev.code` inside the write loop, so the total is unknown until we have already + * evaluated every row (counting it first would mean evaluating the tree twice). Nested + * fixed-width writes therefore grow on demand with `setSafe`; the String / Binary / Decimal + * branches already do, for the same reason. */ private def emitWrite( targetVec: String, idx: String, source: String, dataType: DataType, - ctx: CodegenContext): OutputEmit = dataType match { + ctx: CodegenContext, + nested: Boolean = false): OutputEmit = dataType match { case BooleanType => - OutputEmit("", s"$targetVec.set($idx, $source ? 1 : 0);") + val set = if (nested) "setSafe" else "set" + OutputEmit("", s"$targetVec.$set($idx, $source ? 1 : 0);") case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | DateType | TimestampType | TimestampNTZType => // Spark codegen emits the matching primitive Java type; Arrow `set` overloads accept it. - OutputEmit("", s"$targetVec.set($idx, $source);") + val set = if (nested) "setSafe" else "set" + OutputEmit("", s"$targetVec.$set($idx, $source);") case dt: DecimalType => // DecimalOutputShortFastPath: precision <= 18 fits in a signed long, so pass the unscaled // value to `setSafe(int, long)` and skip the BigDecimal allocation. @@ -250,7 +263,8 @@ private[codegen] object CometBatchKernelCodegenOutput { val childIdx = ctx.freshName("cidx") val jVar = ctx.freshName("j") val elemSource = emitSpecializedGetterExpr(arrVar, jVar, elementType) - val inner = emitWrite(childVar, s"$childIdx + $jVar", elemSource, elementType, ctx) + val inner = + emitWrite(childVar, s"$childIdx + $jVar", elemSource, elementType, ctx, nested = true) val setup = (s"$childClass $childVar = ($childClass) $targetVec.getDataVector();" +: Seq(inner.setup).filter(_.nonEmpty)).mkString("\n") @@ -285,7 +299,11 @@ private[codegen] object CometBatchKernelCodegenOutput { val childDecl = s"$childClass $childVar = ($childClass) $targetVec.getChildByOrdinal($fi);" val fieldSource = emitSpecializedGetterExpr(rowVar, fi.toString, field.dataType) - val inner = emitWrite(childVar, idx, fieldSource, field.dataType, ctx) + // Struct fields are co-indexed with the struct (written at the same `idx`), so a field is + // nested exactly when the struct is: top-level struct fields land at the row index and are + // pre-sized to numRows (bare `set` is in bounds); a struct nested in an array/map inherits + // that parent's cumulative, unbounded index and needs `setSafe`. + val inner = emitWrite(childVar, idx, fieldSource, field.dataType, ctx, nested = nested) val write = if (!field.nullable) { inner.perRow @@ -327,8 +345,10 @@ private[codegen] object CometBatchKernelCodegenOutput { val valClass = outputVectorClass(mt.valueType) val keySrcExpr = emitSpecializedGetterExpr(keyArr, jVar, mt.keyType) val valSrcExpr = emitSpecializedGetterExpr(valArr, jVar, mt.valueType) - val keyEmit = emitWrite(keyVar, s"$childIdx + $jVar", keySrcExpr, mt.keyType, ctx) - val valEmit = emitWrite(valVar, s"$childIdx + $jVar", valSrcExpr, mt.valueType, ctx) + val keyEmit = + emitWrite(keyVar, s"$childIdx + $jVar", keySrcExpr, mt.keyType, ctx, nested = true) + val valEmit = + emitWrite(valVar, s"$childIdx + $jVar", valSrcExpr, mt.valueType, ctx, nested = true) val setup = (Seq( s"$structClass $entriesVar = ($structClass) $targetVec.getDataVector();", diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenAssertions.scala b/spark/src/test/scala/org/apache/comet/CometCodegenAssertions.scala index 13334a5134..bce8bfc598 100644 --- a/spark/src/test/scala/org/apache/comet/CometCodegenAssertions.scala +++ b/spark/src/test/scala/org/apache/comet/CometCodegenAssertions.scala @@ -20,9 +20,12 @@ package org.apache.comet import org.apache.arrow.vector.ValueVector +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType +import org.apache.comet.codegen.CometBatchKernelCodegen import org.apache.comet.udf.codegen.CometScalaUDFCodegen +import org.apache.comet.vector.CometVector /** * Shared assertions for the codegen-dispatcher test suites. Mix in alongside `CometTestBase`. @@ -79,4 +82,27 @@ trait CometCodegenAssertions { s"expected kernel signature $expectedNames -> $output; " + s"cache had ${sigs.map { case (c, d) => (c.map(_.getSimpleName), d) }}") } + + /** + * Compiles `expr` (no input columns), runs one batch of `numRows`, and hands the output + * `CometVector` to `read`. Every row evaluates to the same value (the expression has no input), + * which still exercises the cross-row cumulative child index of the collection output writer: + * the child of a List / Map grows by each row's element count, so a batch of N rows drives the + * accumulation that a single row cannot. Drives the writer directly, without a query plan, so + * it reaches complex-output expressions the serde does not route through dispatch today. The + * vector is closed after `read` returns, so `read` must fully materialize what it needs. + */ + protected def runKernel[T](expr: Expression, numRows: Int)(read: CometVector => T): T = { + val kernel = CometBatchKernelCodegen.compile(expr, IndexedSeq.empty).newInstance() + val field = CometBatchKernelCodegen.toFfiArrowField("out", expr.dataType, nullable = true) + val out = CometBatchKernelCodegen.allocateOutput(field, numRows, 0) + try { + kernel.init(0) + kernel.process(Array.empty[ValueVector], out, numRows) + out.setValueCount(numRows) + read(CometVector.getVector(out, null)) + } finally { + out.close() + } + } } diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenFuzzSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenFuzzSuite.scala index 9167713cad..5c40dccaa3 100644 --- a/spark/src/test/scala/org/apache/comet/CometCodegenFuzzSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCodegenFuzzSuite.scala @@ -27,12 +27,18 @@ import scala.util.Random import org.apache.commons.io.FileUtils import org.apache.spark.SparkConf import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.DataTypeSupport.isComplexType +import org.apache.comet.codegen.CometBatchKernelCodegen import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGenerator, SchemaGenOptions} +import org.apache.comet.vector.CometVector /** * Randomized end-to-end tests for the Arrow-direct codegen dispatcher: schema-driven coverage of @@ -406,4 +412,108 @@ class CometCodegenFuzzSuite } } } + + /** + * Randomized output-writer coverage (#4539). Generates a random nested output type and a random + * catalyst value of that type, wraps it in a `Literal`, and drives it through the kernel output + * writer with [[runKernel]]. Reading the Arrow output back must reproduce the value. + * + * Random Array / Map sizes mean each collection's child vector fills at a cumulative index that + * `numRows` does not bound, so the writer must grow the child with `setSafe` (the #4539 fix). A + * multi-row batch additionally exercises the cumulative index across rows. The root is always a + * collection so the nested-write path always runs. The generated value is its own oracle: + * `CatalystTypeConverters.convertToScala` materializes both the value and the Arrow readback + * (both expose the catalyst ArrayData / MapData / InternalRow interface) and the two must + * compare equal. + */ + private val outputLeafTypes: Seq[DataType] = + Seq(IntegerType, LongType, DoubleType, BooleanType, StringType, DecimalType(10, 2)) + + private def randomLeafType(r: Random): DataType = + outputLeafTypes(r.nextInt(outputLeafTypes.size)) + + /** Random nested type, biased toward leaves as depth runs out. Map keys are always leaves. */ + private def randomOutputType(r: Random, depth: Int): DataType = + if (depth <= 0 || r.nextDouble() < 0.4) randomLeafType(r) + else + r.nextInt(3) match { + case 0 => ArrayType(randomOutputType(r, depth - 1), containsNull = true) + case 1 => + MapType(randomLeafType(r), randomOutputType(r, depth - 1), valueContainsNull = true) + case _ => + StructType((0 to r.nextInt(2)).map(i => + StructField(s"f$i", randomOutputType(r, depth - 1), nullable = true))) + } + + private def randomLeafValue(r: Random, dt: DataType): Any = dt match { + case IntegerType => r.nextInt() + case LongType => r.nextLong() + case DoubleType => r.nextDouble() + case BooleanType => r.nextBoolean() + case StringType => UTF8String.fromString(s"s${r.nextInt(1000000)}") + case d: DecimalType => Decimal((r.nextInt(2000000) - 1000000).toLong, d.precision, d.scale) + case other => throw new IllegalArgumentException(s"unexpected leaf type $other") + } + + /** Random catalyst value of `dt`; `nullable` permits an occasional null element / field. */ + private def randomOutputValue(r: Random, dt: DataType, nullable: Boolean): Any = { + if (nullable && r.nextDouble() < 0.2) null + else + dt match { + case ArrayType(e, containsNull) => + val n = r.nextInt(40) + new GenericArrayData( + (0 until n).map(_ => randomOutputValue(r, e, containsNull)).toArray[Any]) + case MapType(k, v, valueContainsNull) => + // Dedup by materialized key so the map round-trips 1:1 (Spark map keys are distinct). + val entries = scala.collection.mutable.LinkedHashMap.empty[Any, Any] + (0 until r.nextInt(20)).foreach { _ => + val key = randomOutputValue(r, k, nullable = false) + entries.getOrElseUpdate(key, randomOutputValue(r, v, valueContainsNull)) + } + new ArrayBasedMapData( + new GenericArrayData(entries.keys.toArray[Any]), + new GenericArrayData(entries.values.toArray[Any])) + case st: StructType => + new GenericInternalRow( + st.fields.map(f => randomOutputValue(r, f.dataType, f.nullable)).toArray[Any]) + case leaf => randomLeafValue(r, leaf) + } + } + + /** Reads the root collection value of `vec` at `row` as a catalyst ArrayData / MapData. */ + private def readRoot(vec: CometVector, dt: DataType, row: Int): Any = dt match { + case _: ArrayType => vec.getArray(row) + case _: MapType => vec.getMap(row) + case other => throw new IllegalArgumentException(s"unexpected root type $other") + } + + test("randomized dynamically-sized collection output round-trips through the writer (#4539)") { + val r = new Random(42) + val numRows = 4 // > 1 so the child's cumulative index accumulates across rows + // canHandle may reject a generated type (e.g. the maxFields gate on a wide nesting); count + // the ones we actually drove through the writer to guard against a vacuous run. + val exercised = (0 until 300).count { _ => + // Root is always a collection so the nested-child write path runs every iteration. + val dt = + if (r.nextBoolean()) ArrayType(randomOutputType(r, 2), containsNull = true) + else MapType(randomLeafType(r), randomOutputType(r, 2), valueContainsNull = true) + val value = randomOutputValue(r, dt, nullable = false) + val expr = Literal(value, dt) + val handled = CometBatchKernelCodegen.canHandle(expr).isEmpty + if (handled) { + val expected = CatalystTypeConverters.convertToScala(value, dt) + runKernel(expr, numRows) { vec => + (0 until numRows).foreach { row => + val actual = CatalystTypeConverters.convertToScala(readRoot(vec, dt, row), dt) + assert( + actual === expected, + s"row $row mismatch for output type $dt\n expected=$expected\n actual=$actual") + } + } + } + handled + } + assert(exercised > 0, "every generated type was rejected by canHandle (of 300 generated)") + } } diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala index e8f59e8b22..a609ccbfe2 100644 --- a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala @@ -450,6 +450,46 @@ class CometCodegenSourceSuite extends AnyFunSuite { } } + test("nested fixed-width map children grow with setSafe, not set (#4539)") { + // Map output: both key and value are fixed-width children of the entries struct. + // Their element count is the data-dependent sum of per-row map sizes, not bounded by numRows, + // and is unknown until the write loop has evaluated each row, so the writes must use `setSafe` + // to grow on demand. A bare `set` throws once a row's entries exceed the child's initial + // capacity (issue #4539: the literal map's third key overflowed the pre-sized IntVector). + val expr = CreateMap( + Seq( + Literal(1, IntegerType), + Literal(10, IntegerType), + Literal(2, IntegerType), + Literal(20, IntegerType))) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq.empty).body + assert( + src.contains(".setSafe("), + s"expected setSafe for nested fixed-width writes; got:\n$src") + // `.set(` is a bare fixed-width write; `setSafe(` / `setNull(` / `setIndexDefined(` do not + // match this literal. There must be none into the nested children. + assert( + !src.contains(".set("), + s"expected no bare fixed-width set into map children; got:\n$src") + } + + test("top-level scalar output keeps the pre-sized set fast path") { + // The root output vector is pre-sized to numRows and written once per row, so it uses the + // bare `set` fast path rather than paying for setSafe's per-write capacity check. This pins + // the boundary the #4539 fix draws: setSafe is for nested children only. + val expr = Add(BoundReference(0, IntegerType, nullable = false), Literal(1, IntegerType)) + val intSpec = ArrowColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(intSpec)).body + assert( + src.contains("output.set("), + s"expected bare set for the pre-sized root output; got:\n$src") + assert( + !src.contains(".setSafe("), + s"expected no setSafe for a scalar root output; got:\n$src") + } + test("ArrayType output elides isNullAt on the element loop when containsNull is false") { // CreateArray over only-non-null Literals produces ArrayType(elementType, containsNull=false). // The element write should drop the `arr.isNullAt(j)` guard at source level rather than diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenSuite.scala index cdae68c90a..76ffbc8c99 100644 --- a/spark/src/test/scala/org/apache/comet/CometCodegenSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCodegenSuite.scala @@ -19,12 +19,16 @@ package org.apache.comet +import scala.util.Random + import org.apache.arrow.vector._ import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.api.java.UDF1 +import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateMap, CreateNamedStruct, Expression, Literal, MapConcat} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.udf.codegen.CometScalaUDFCodegen @@ -1021,6 +1025,77 @@ class CometCodegenSuite } } + private def kernelMapIntString(expr: Expression): Map[Int, String] = + runKernel(expr, 1) { v => + val map = v.getMap(0) + val keys = map.keyArray() + val values = map.valueArray() + (0 until map.numElements()) + .map(i => keys.getInt(i) -> values.getUTF8String(i).toString) + .toMap + } + + test("constant-folded map_concat output round-trips every key through the kernel (#4539)") { + // map_concat(map(1,'a',2,'b'), map(3,'c')) is all-literal, so Spark's optimizer constant-folds + // it to a Literal(MapType) holding an ArrayBasedMapData. The MapType output writer must marshal + // every entry into the Arrow MapVector; the reported bug corrupts the last key (3 -> 0). + def s(str: String): Literal = Literal(UTF8String.fromString(str), StringType) + val map1 = + CreateMap(Seq(Literal(1), s("a"), Literal(2), s("b")), useStringTypeWhenEmpty = false) + val map2 = CreateMap(Seq(Literal(3), s("c")), useStringTypeWhenEmpty = false) + val folded = + Literal.create(MapConcat(Seq(map1, map2)).eval(null), MapType(IntegerType, StringType)) + assert(kernelMapIntString(folded) === Map(1 -> "a", 2 -> "b", 3 -> "c")) + } + + test("constant-folded array output writes every element past the pre-sized child (#4539)") { + // A single-row array with far more elements than the list child's numRows-derived initial + // capacity. The element child is written at a cumulative index, so a bare `set` overflows the + // pre-sized buffer once the row's element count exceeds it; `setSafe` grows it. Sibling of the + // map_concat case for ArrayType. + val n = 16 + val elems = (0 until n).map(i => Literal(i * 10, IntegerType)) + val folded = + Literal.create(CreateArray(elems).eval(null), ArrayType(IntegerType, containsNull = false)) + + val got = runKernel(folded, 1) { v => + val arr = v.getArray(0) + (0 until arr.numElements()).map(arr.getInt) + } + assert(got === (0 until n).map(_ * 10)) + } + + test( + "constant-folded Array> writes struct fields past the pre-sized child " + + "(#4539)") { + // The struct sits inside an array, so its fields inherit the array's cumulative index. The + // fixed-width Int field would overflow with a bare `set`; propagating `nested` into the struct + // branch makes it `setSafe`. Guards the struct-nested-in-collection path. + val n = 16 + def structAt(i: Int): Expression = + CreateNamedStruct( + Seq( + Literal("a"), + Literal(i, IntegerType), + Literal("b"), + Literal(UTF8String.fromString(s"v$i"), StringType))) + val structType = new StructType() + .add("a", IntegerType, nullable = false) + .add("b", StringType, nullable = false) + val folded = Literal.create( + CreateArray((0 until n).map(structAt)).eval(null), + ArrayType(structType, containsNull = false)) + + val got = runKernel(folded, 1) { v => + val arr = v.getArray(0) + (0 until arr.numElements()).map { i => + val r = arr.getStruct(i, 2) + r.getInt(0) -> r.getUTF8String(1).toString + } + } + assert(got === (0 until n).map(i => i -> s"v$i")) + } + test("array_distinct on Array> retains element identity across hash set") { // Fuzz signal: cardinality(array_distinct(arr_of_struct)) returns 1 where Spark returns 2. // Hypothesis: the kernel's InputStruct wrapper backing array_distinct's element reads is @@ -1153,6 +1228,174 @@ class CometCodegenSuite // Runtime coverage for nullable nested `getStruct` / `getArray` / `getMap` element reads is // exercised through HOFs in `CometCodegenHOFSuite`. Static emitter assertions live in // `CometCodegenSourceSuite`. + + /** + * Dynamically sized collection output (regression family for #4539). Each UDF takes a scalar + * seed and returns a collection whose per-row size is a function of the seed, so the output + * writer fills each collection's child vector at a cumulative index that `numRows` does not + * bound. Before #4539 the fixed-width child writes used a bare `set`, which ran off the end of + * the pre-sized buffer; with Comet's unsafe Arrow memory the overflow corrupted neighboring + * entries (or, under NMT, aborted the JVM). Scalar input keeps the read side off the + * complex-input deserializer, isolating coverage to the writer. + * + * A small batch size makes the child's `numRows`-derived pre-size tiny relative to the per-row + * element counts, so the larger rows reliably push past it. Randomized type/shape coverage of + * the same writer lives in `CometCodegenFuzzSuite`. + */ + private val collectionOutputSeeds: Seq[String] = { + val rng = new Random(42) + (0 until 256).map { i => + if (i % 17 == 0) "NULL" // null result + else if (i % 13 == 0) "0" // empty collection + else (rng.nextInt(80) - 39).toString // mix of small and larger-than-batch sizes + } + } + + private def withSeedTable(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (seed INT) USING parquet") + collectionOutputSeeds.grouped(64).foreach { batch => + sql(s"INSERT INTO t VALUES ${batch.map(s => s"($s)").mkString(", ")}") + } + f + } + } + + private case class CollectionOutputCase(label: String, register: () => String) + + private val collectionOutputCases: Seq[CollectionOutputCase] = Seq( + // Fixed-width element with nulls: the exact nested write #4539 corrupted. + CollectionOutputCase( + "Array with null elements", + () => { + val n = "arrout_int" + spark.udf.register( + n, + (i: java.lang.Integer) => + if (i == null) null + else + (0 until (math.abs(i.intValue) % 40)).map(j => + if (j % 4 == 0) null else java.lang.Integer.valueOf(i + j))) + n + }), + CollectionOutputCase( + "Array", + () => { + val n = "arrout_long" + spark.udf.register( + n, + (i: java.lang.Integer) => + if (i == null) null + else (0 until (math.abs(i.intValue) % 40)).map(j => (i.toLong + j) * 1000000000L)) + n + }), + CollectionOutputCase( + "Array with null elements", + () => { + val n = "arrout_str" + spark.udf.register( + n, + (i: java.lang.Integer) => + if (i == null) null + else + (0 until (math.abs(i.intValue) % 40)).map(j => + if (j % 3 == 0) null else s"v${i}_$j")) + n + }), + CollectionOutputCase( + "Array", + () => { + val n = "arrout_dec" + spark.udf.register( + n, + (i: java.lang.Integer) => + if (i == null) null + else + (0 until (math.abs(i.intValue) % 40)).map(j => + java.math.BigDecimal.valueOf((i + j).toLong))) + n + }), + CollectionOutputCase( + "Array", + () => { + val n = "arrout_bin" + spark.udf.register( + n, + (i: java.lang.Integer) => + if (i == null) null + else + (0 until (math.abs(i.intValue) % 40)).map(j => + if (j % 5 == 0) null else s"b${i}_$j".getBytes("UTF-8"))) + n + }), + CollectionOutputCase( + "Map", + () => { + val n = "mapout_ii" + spark.udf.register( + n, + (i: java.lang.Integer) => + if (i == null) null + else (0 until (math.abs(i.intValue) % 40)).map(j => j -> (i + j)).toMap) + n + }), + CollectionOutputCase( + "Map", + () => { + val n = "mapout_si" + spark.udf.register( + n, + (i: java.lang.Integer) => + if (i == null) null + else (0 until (math.abs(i.intValue) % 40)).map(j => s"k$j" -> (i + j)).toMap) + n + }), + CollectionOutputCase( + "Array>", + () => { + val n = "arrout_arr" + spark.udf.register( + n, + (i: java.lang.Integer) => + if (i == null) null + else (0 until (math.abs(i.intValue) % 40)).map(j => (0 to j).map(_ + i))) + n + }), + CollectionOutputCase( + "Map>", + () => { + val n = "mapout_iarr" + spark.udf.register( + n, + (i: java.lang.Integer) => + if (i == null) null + else (0 until (math.abs(i.intValue) % 40)).map(j => j -> (0 to j).map(_ + i)).toMap) + n + }), + CollectionOutputCase( + "Array>", + () => { + val n = "arrout_struct" + spark.udf.register( + n, + (i: java.lang.Integer) => + if (i == null) null + else (0 until (math.abs(i.intValue) % 40)).map(j => IntStr(i + j, s"v$j"))) + n + })) + + for (c <- collectionOutputCases) { + test(s"dynamically-sized ${c.label} output round-trips through codegen dispatch (#4539)") { + val udf = c.register() + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "8") { + withSeedTable { + assertCodegenRan { + checkSparkAnswerAndOperator(sql(s"SELECT $udf(seed) FROM t")) + } + } + } + } + } } /** @@ -1165,3 +1408,6 @@ private case class NameAgePair(name: String, age: Int) private case class NameItems(name: String, items: Seq[Int]) private case class XyPair(x: Int, y: String) + +/** Element type for the `Array>` dynamically-sized output case. */ +private case class IntStr(a: Int, b: String)