diff --git a/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/gluten/uniffle/UniffleShuffleManager.java b/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/gluten/uniffle/UniffleShuffleManager.java
index f91141c1eb84..b84f6b3b91c0 100644
--- a/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/gluten/uniffle/UniffleShuffleManager.java
+++ b/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/gluten/uniffle/UniffleShuffleManager.java
@@ -42,7 +42,8 @@ private boolean isDriver() {
 
   public UniffleShuffleManager(SparkConf conf, boolean isDriver) {
     super(conf, isDriver);
-    conf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssSparkConfig.RSS_ROW_BASED.key(), "false");
+    // FIXME: remove this after https://github.com/apache/incubator-uniffle/pull/2193
+    conf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
   }
 
   @Override
@@ -69,6 +70,13 @@ public <K, V> ShuffleWriter<K, V> getWriter(
     } else {
       writeMetrics = context.taskMetrics().shuffleWriteMetrics();
     }
+    // set rss.row.based to false to mark it as columnar shuffle
+    SparkConf conf =
+        sparkConf
+            .clone()
+            .set(
+                RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssSparkConfig.RSS_ROW_BASED.key(),
+                "false");
     return new VeloxUniffleColumnarShuffleWriter<>(
         context.partitionId(),
         rssHandle.getAppId(),
@@ -77,7 +85,7 @@ public <K, V> ShuffleWriter<K, V> getWriter(
         context.taskAttemptId(),
         writeMetrics,
         this,
-        sparkConf,
+        conf,
         shuffleWriteClient,
         rssHandle,
         this::markFailedTask,
diff --git a/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java b/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
index ca5b3ad9529f..e53605f284b7 100644
--- a/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
+++ b/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
@@ -41,6 +41,7 @@ import org.apache.spark.util.SparkResourceUtil;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.exception.RssException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -125,9 +126,9 @@ public VeloxUniffleColumnarShuffleWriter(
   }
 
   @Override
-  protected void writeImpl(Iterator<Product2<K, V>> records) throws IOException {
-    if (!records.hasNext() && !isMemoryShuffleEnabled) {
-      super.sendCommit();
+  protected void writeImpl(Iterator<Product2<K, V>> records) {
+    if (!records.hasNext()) {
+      sendCommit();
       return;
     }
     // writer already init
@@ -185,12 +186,19 @@ public long spill(MemoryTarget self, Spiller.Phase phase, long size) {
       }
     }
 
-    long startTime = System.nanoTime();
     LOG.info("nativeShuffleWriter value {}", nativeShuffleWriter);
+    // If all of the ColumnarBatch have empty rows, the nativeShuffleWriter still equals -1
    if (nativeShuffleWriter == -1L) {
-      throw new IllegalStateException("nativeShuffleWriter should not be -1L");
+      sendCommit();
+      return;
+    }
+    long startTime = System.nanoTime();
+    SplitResult splitResult;
+    try {
+      splitResult = jniWrapper.stop(nativeShuffleWriter);
+    } catch (IOException e) {
+      throw new RssException(e);
     }
-    splitResult = jniWrapper.stop(nativeShuffleWriter);
     columnarDep
         .metrics()
         .get("splitTime")
@@ -210,9 +218,7 @@ public long spill(MemoryTarget self, Spiller.Phase phase, long size) {
     long pushMergedDataTime = System.nanoTime();
     // clear all
     sendRestBlockAndWait();
-    if (!isMemoryShuffleEnabled) {
-      super.sendCommit();
-    }
+    sendCommit();
     long writeDurationMs = System.nanoTime() - pushMergedDataTime;
     shuffleWriteMetrics.incWriteTime(writeDurationMs);
     LOG.info(
@@ -220,6 +226,13 @@ public long spill(MemoryTarget self, Spiller.Phase phase, long size) {
         TimeUnit.MILLISECONDS.toNanos(writeDurationMs));
   }
 
+  @Override
+  protected void sendCommit() {
+    if (!isMemoryShuffleEnabled) {
+      super.sendCommit();
+    }
+  }
+
   @Override
   public Option<MapStatus> stop(boolean success) {
     if (!stopping) {