Commit 9254492

WeichenXu123 authored and mengxr committed
[SPARK-22666][ML][SQL] Spark datasource for image format
## What changes were proposed in this pull request?

Implement an image schema datasource. This image datasource supports:

- partition discovery (loading partitioned images)
- `dropImageFailures` (the same behavior as `ImageSchema.readImages`)
- path wildcard matching (the same behavior as `ImageSchema.readImages`)
- loading recursively from a directory (unlike `ImageSchema.readImages`, use a path such as `/path/to/dir/**`)

This datasource does **NOT** support:

- specifying `numPartitions` (it is determined automatically by the datasource)
- sampling (you can call `df.sample` afterwards, but the sampling operator is not pushed down to the datasource)

## How was this patch tested?

Unit tests.

## Benchmark

I benchmarked and compared the running time of the old `ImageSchema.readImages` API and this image datasource.

**cluster**: 4 nodes, each with 64 GB memory and an 8-core CPU
**test dataset**: Flickr8k_Dataset (about 8091 images)

**time cost**:
- image datasource (automatically generated 258 partitions): 38.04 s
- `ImageSchema.readImages` (16 partitions): 68.4 s
- `ImageSchema.readImages` (258 partitions): 90.6 s

**time cost with the image count doubled (Flickr8k_Dataset cloned to load twice as many images)**:
- image datasource (automatically generated 515 partitions): 95.4 s
- `ImageSchema.readImages` (32 partitions): 109 s
- `ImageSchema.readImages` (515 partitions): 105 s

So the image datasource implementation in this PR brings a performance improvement over the old `ImageSchema.readImages` API.

Closes apache#22328 from WeichenXu123/image_datasource.

Authored-by: WeichenXu <[email protected]>
Signed-off-by: Xiangrui Meng <[email protected]>
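For readers unfamiliar with the `/path/to/dir/**` convention used for recursive loading, here is a minimal, Spark-free sketch of recursive wildcard matching using Python's `glob` module (the directory layout and file names are made up for illustration):

```python
import glob
import os
import tempfile

# Create a small nested directory tree with a few fake image files.
root = tempfile.mkdtemp()
for rel in ["a/x.jpg", "a/b/y.jpg", "z.jpg"]:
    path = os.path.join(root, rel)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    open(path, "w").close()

# With recursive=True, `**` matches files at any depth (including zero
# directories deep), the same idea as a `/path/to/dir/**` load pattern.
matches = sorted(glob.glob(os.path.join(root, "**", "*.jpg"), recursive=True))
print(len(matches))  # → 3
```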
1 parent c66eef8 commit 9254492

File tree

27 files changed, +323 -4 lines changed

File renamed without changes.

data/mllib/images/origin/license.txt (+13)

@@ -0,0 +1,13 @@
The images in the folder "kittens" are under the creative commons CC0 license, or no rights reserved:
https://creativecommons.org/share-your-work/public-domain/cc0/
The images are taken from:
https://ccsearch.creativecommons.org/image/detail/WZnbJSJ2-dzIDiuUUdto3Q==
https://ccsearch.creativecommons.org/image/detail/_TlKu_rm_QrWlR0zthQTXA==
https://ccsearch.creativecommons.org/image/detail/OPNnHJb6q37rSZ5o_L5JHQ==
https://ccsearch.creativecommons.org/image/detail/B2CVP_j5KjwZm7UAVJ3Hvw==

The chr30.4.184.jpg and grayscale.jpg images are also under the CC0 license, taken from:
https://ccsearch.creativecommons.org/image/detail/8eO_qqotBfEm2UYxirLntw==

The image under "multi-channel" directory is under the CC BY-SA 4.0 license cropped from:
https://en.wikipedia.org/wiki/Alpha_compositing#/media/File:Hue_alpha_falloff.png
@@ -0,0 +1 @@
not an image
@@ -1 +1,2 @@
 org.apache.spark.ml.source.libsvm.LibSVMFileFormat
+org.apache.spark.ml.source.image.ImageFileFormat
@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.source.image

/**
 * `image` package implements Spark SQL data source API for loading image data as `DataFrame`.
 * The loaded `DataFrame` has one `StructType` column: `image`.
 * The schema of the `image` column is:
 *  - origin: String (represents the file path of the image)
 *  - height: Int (height of the image)
 *  - width: Int (width of the image)
 *  - nChannels: Int (number of the image channels)
 *  - mode: Int (OpenCV-compatible type)
 *  - data: BinaryType (Image bytes in OpenCV-compatible order: row-wise BGR in most cases)
 *
 * To use image data source, you need to set "image" as the format in `DataFrameReader` and
 * optionally specify the data source options, for example:
 * {{{
 *   // Scala
 *   val df = spark.read.format("image")
 *     .option("dropInvalid", true)
 *     .load("data/mllib/images/partitioned")
 *
 *   // Java
 *   Dataset<Row> df = spark.read().format("image")
 *     .option("dropInvalid", true)
 *     .load("data/mllib/images/partitioned");
 * }}}
 *
 * Image data source supports the following options:
 *  - "dropInvalid": Whether to drop the files that are not valid images from the result.
 *
 * @note This IMAGE data source does not support saving images to files.
 *
 * @note This class is public for documentation purpose. Please don't use this class directly.
 * Rather, use the data source API as illustrated above.
 */
class ImageDataSource private() {}
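As an illustration only (not Spark API), the `image` struct documented above can be mirrored as a plain Python record. The `invalid_image_record` helper name is hypothetical; the -1/empty-bytes convention for invalid images follows the `dropInvalid` documentation in this commit:

```python
from typing import NamedTuple

class ImageRecord(NamedTuple):
    """Mirrors the fields of the `image` struct column documented above."""
    origin: str      # file path of the image
    height: int
    width: int
    nChannels: int
    mode: int        # OpenCV-compatible type
    data: bytes      # row-wise BGR bytes in most cases

# When dropInvalid=false, an unreadable file keeps its origin but gets
# empty data and -1 in the numeric fields, per the ImageOptions docs.
def invalid_image_record(origin: str) -> ImageRecord:
    return ImageRecord(origin, -1, -1, -1, -1, b"")

print(invalid_image_record("file:/tmp/broken.png").height)  # → -1
```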
@@ -0,0 +1,100 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.source.image

import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeRow}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

private[image] class ImageFileFormat extends FileFormat with DataSourceRegister {

  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = Some(ImageSchema.imageSchema)

  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    throw new UnsupportedOperationException("Write is not supported for image data source")
  }

  override def shortName(): String = "image"

  override protected def buildReader(
      sparkSession: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
    assert(
      requiredSchema.length <= 1,
      "Image data source only produces a single data column named \"image\".")

    val broadcastedHadoopConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

    val imageSourceOptions = new ImageOptions(options)

    (file: PartitionedFile) => {
      val emptyUnsafeRow = new UnsafeRow(0)
      if (!imageSourceOptions.dropInvalid && requiredSchema.isEmpty) {
        Iterator(emptyUnsafeRow)
      } else {
        val origin = file.filePath
        val path = new Path(origin)
        val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
        val stream = fs.open(path)
        val bytes = try {
          ByteStreams.toByteArray(stream)
        } finally {
          Closeables.close(stream, true)
        }
        val resultOpt = ImageSchema.decode(origin, bytes)
        val filteredResult = if (imageSourceOptions.dropInvalid) {
          resultOpt.toIterator
        } else {
          Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin)))
        }

        if (requiredSchema.isEmpty) {
          filteredResult.map(_ => emptyUnsafeRow)
        } else {
          val converter = RowEncoder(requiredSchema)
          filteredResult.map(row => converter.toRow(row))
        }
      }
    }
  }
}
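The control flow of the reader closure above (decode the file bytes, then either drop failures or emit a placeholder row) can be sketched in plain Python. Here `decode` is a toy stand-in for `ImageSchema.decode` that only checks JPEG/PNG magic bytes, not the real decoder:

```python
from typing import Iterator, Optional

# Toy stand-in for ImageSchema.decode: returns None when the bytes are
# not a recognizable image (here, a simple JPEG/PNG magic-number check).
def decode(origin: str, data: bytes) -> Optional[dict]:
    if data[:2] == b"\xff\xd8" or data[:8] == b"\x89PNG\r\n\x1a\n":
        return {"origin": origin, "data": data}
    return None

def invalid_image_row(origin: str) -> dict:
    return {"origin": origin, "data": b""}

# Mirrors the per-file closure: try to decode; either drop failures or
# emit a placeholder row, depending on dropInvalid.
def read_file(origin: str, data: bytes, drop_invalid: bool) -> Iterator[dict]:
    result = decode(origin, data)
    if drop_invalid:
        return iter([result] if result is not None else [])
    return iter([result if result is not None else invalid_image_row(origin)])

print(len(list(read_file("a.txt", b"not an image", drop_invalid=True))))   # → 0
print(len(list(read_file("a.txt", b"not an image", drop_invalid=False))))  # → 1
```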
@@ -0,0 +1,32 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.source.image

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

private[image] class ImageOptions(
    @transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {

  def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))

  /**
   * Whether to drop invalid images. If true, invalid images will be removed, otherwise
   * invalid images will be returned with empty data and all other fields filled with `-1`.
   */
  val dropInvalid = parameters.getOrElse("dropInvalid", "false").toBoolean
}
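Because the parameters are wrapped in `CaseInsensitiveMap`, any casing of `dropInvalid` selects the same option, and it defaults to false when absent. A minimal Python analogue (illustrative only, not the Spark class):

```python
# Mirrors ImageOptions: keys are matched case-insensitively by
# normalizing them to lower case, and dropInvalid defaults to false.
class ImageOptions:
    def __init__(self, parameters: dict):
        self._params = {k.lower(): v for k, v in parameters.items()}

    @property
    def drop_invalid(self) -> bool:
        return self._params.get("dropinvalid", "false").lower() == "true"

print(ImageOptions({"DropInvalid": "true"}).drop_invalid)  # → True
print(ImageOptions({}).drop_invalid)                       # → False
```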

mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala (+1 -1)

@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._

 class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext {
   // Single column of images named "image"
-  private lazy val imagePath = "../data/mllib/images"
+  private lazy val imagePath = "../data/mllib/images/origin"

   test("Smoke test: create basic ImageSchema dataframe") {
     val origin = "path"
@@ -0,0 +1,119 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.source.image

import java.nio.file.Paths

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.image.ImageSchema._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, substring_index}

class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext {

  // Single column of images named "image"
  private lazy val imagePath = "../data/mllib/images/partitioned"

  test("image datasource count test") {
    val df1 = spark.read.format("image").load(imagePath)
    assert(df1.count === 9)

    val df2 = spark.read.format("image").option("dropInvalid", true).load(imagePath)
    assert(df2.count === 8)
  }

  test("image datasource test: read jpg image") {
    val df = spark.read.format("image").load(imagePath + "/cls=kittens/date=2018-02/DP153539.jpg")
    assert(df.count() === 1)
  }

  test("image datasource test: read png image") {
    val df = spark.read.format("image").load(imagePath + "/cls=multichannel/date=2018-01/BGRA.png")
    assert(df.count() === 1)
  }

  test("image datasource test: read non image") {
    val filePath = imagePath + "/cls=kittens/date=2018-01/not-image.txt"
    val df = spark.read.format("image").option("dropInvalid", true)
      .load(filePath)
    assert(df.count() === 0)

    val df2 = spark.read.format("image").option("dropInvalid", false)
      .load(filePath)
    assert(df2.count() === 1)
    val result = df2.head()
    assert(result === invalidImageRow(
      Paths.get(filePath).toAbsolutePath().normalize().toUri().toString))
  }

  test("image datasource partition test") {
    val result = spark.read.format("image")
      .option("dropInvalid", true).load(imagePath)
      .select(substring_index(col("image.origin"), "/", -1).as("origin"), col("cls"), col("date"))
      .collect()

    assert(Set(result: _*) === Set(
      Row("29.5.a_b_EGDP022204.jpg", "kittens", "2018-01"),
      Row("54893.jpg", "kittens", "2018-02"),
      Row("DP153539.jpg", "kittens", "2018-02"),
      Row("DP802813.jpg", "kittens", "2018-02"),
      Row("BGRA.png", "multichannel", "2018-01"),
      Row("BGRA_alpha_60.png", "multichannel", "2018-01"),
      Row("chr30.4.184.jpg", "multichannel", "2018-02"),
      Row("grayscale.jpg", "multichannel", "2018-02")
    ))
  }

  // Images with the different number of channels
  test("readImages pixel values test") {
    val images = spark.read.format("image").option("dropInvalid", true)
      .load(imagePath + "/cls=multichannel/").collect()

    val firstBytes20Set = images.map { rrow =>
      val row = rrow.getAs[Row]("image")
      val filename = Paths.get(getOrigin(row)).getFileName().toString()
      val mode = getMode(row)
      val bytes20 = getData(row).slice(0, 20).toList
      filename -> Tuple2(mode, bytes20) // Cannot remove `Tuple2`, otherwise `->` operator
                                        // will match 2 arguments
    }.toSet

    assert(firstBytes20Set === expectedFirstBytes20Set)
  }

  // number of channels and first 20 bytes of OpenCV representation
  // - default representation for 3-channel RGB images is BGR row-wise:
  //   (B00, G00, R00, B10, G10, R10, ...)
  // - default representation for 4-channel RGB images is BGRA row-wise:
  //   (B00, G00, R00, A00, B10, G10, R10, A10, ...)
  private val expectedFirstBytes20Set = Set(
    "grayscale.jpg" ->
      ((0, List[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62,
        -57, -60, -63, -53, -49, -55, -69))),
    "chr30.4.184.jpg" -> ((16,
      List[Byte](-9, -3, -1, -43, -32, -28, -75, -60, -57, -78, -59, -56, -74, -59, -57,
        -71, -58, -56, -73, -64))),
    "BGRA.png" -> ((24,
      List[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128,
        -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))),
    "BGRA_alpha_60.png" -> ((24,
      List[Byte](-128, -128, -8, 60, -128, -128, -8, 60, -128,
        -128, -8, 60, 127, 127, -9, 60, 127, 127, -9, 60)))
  )
}
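The expected-bytes comment in the suite describes OpenCV's interleaved row-wise channel order (BGR for 3 channels, BGRA for 4). A small Python sketch of converting row-major RGB(A) pixels into that layout (toy pixel values, not the test images):

```python
# Convert a row-major grid of (R, G, B[, A]) pixels into OpenCV's
# interleaved row-wise byte order: BGR for 3 channels, BGRA for 4.
def to_opencv_bytes(pixels):
    out = bytearray()
    for row in pixels:
        for px in row:
            r, g, b = px[0], px[1], px[2]
            out += bytes([b, g, r])
            if len(px) == 4:          # keep alpha last, as in BGRA
                out.append(px[3])
    return bytes(out)

# One row, two RGB pixels: (R=1,G=2,B=3) then (R=4,G=5,B=6)
print(list(to_opencv_bytes([[(1, 2, 3), (4, 5, 6)]])))  # → [3, 2, 1, 6, 5, 4]
```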

python/pyspark/ml/image.py (+1 -1)

@@ -216,7 +216,7 @@ def readImages(self, path, recursive=False, numPartitions=-1,
         :return: a :class:`DataFrame` with a single column of "images",
             see ImageSchema for details.

-        >>> df = ImageSchema.readImages('data/mllib/images/kittens', recursive=True)
+        >>> df = ImageSchema.readImages('data/mllib/images/origin/kittens', recursive=True)
         >>> df.count()
         5

python/pyspark/ml/tests.py (+2 -2)

@@ -2186,7 +2186,7 @@ def tearDown(self):
 class ImageReaderTest(SparkSessionTestCase):

     def test_read_images(self):
-        data_path = 'data/mllib/images/kittens'
+        data_path = 'data/mllib/images/origin/kittens'
         df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
         self.assertEqual(df.count(), 4)
         first_row = df.take(1)[0][0]

@@ -2253,7 +2253,7 @@ def tearDownClass(cls):
     def test_read_images_multiple_times(self):
         # This test case is to check if `ImageSchema.readImages` tries to
         # initiate Hive client multiple times. See SPARK-22651.
-        data_path = 'data/mllib/images/kittens'
+        data_path = 'data/mllib/images/origin/kittens'
         ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
         ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
