@@ -23,25 +23,23 @@ import java.net.URI
 import scala.util.control.NonFatal

 import org.apache.avro.Schema
-import org.apache.avro.file.DataFileConstants._
 import org.apache.avro.file.DataFileReader
 import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
-import org.apache.avro.mapred.{AvroOutputFormat, FsInput}
-import org.apache.avro.mapreduce.AvroJob
+import org.apache.avro.mapred.FsInput
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.Job

-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
 import org.apache.spark.sql.types._
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.SerializableConfiguration

-private[avro] class AvroFileFormat extends FileFormat
+private[sql] class AvroFileFormat extends FileFormat
   with DataSourceRegister with Logging with Serializable {

   override def equals(other: Any): Boolean = other match {
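
Note: the visibility widening in this hunk (private[avro] to private[sql]) is what lets
code elsewhere under org.apache.spark.sql reach this class, presumably including the
shared AvroUtils helpers the later hunks delegate to. A minimal sketch of the two
qualifiers, with hypothetical class names:

    package org.apache.spark.sql.avro

    // Reachable from any code under the org.apache.spark.sql package tree.
    private[sql] class VisibleAcrossSql

    // Reachable only from within org.apache.spark.sql.avro itself.
    private[avro] class VisibleInAvroOnly
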
@@ -56,74 +54,7 @@ private[avro] class AvroFileFormat extends FileFormat
       spark: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
-    val conf = spark.sessionState.newHadoopConf()
-    if (options.contains("ignoreExtension")) {
-      logWarning(s"Option ${AvroOptions.ignoreExtensionKey} is deprecated. Please use the " +
-        "general data source option pathGlobFilter for filtering file names.")
-    }
-    val parsedOptions = new AvroOptions(options, conf)
-
-    // User can specify an optional avro json schema.
-    val avroSchema = parsedOptions.schema
-      .map(new Schema.Parser().parse)
-      .getOrElse {
-        inferAvroSchemaFromFiles(files, conf, parsedOptions.ignoreExtension,
-          spark.sessionState.conf.ignoreCorruptFiles)
-      }
-
-    SchemaConverters.toSqlType(avroSchema).dataType match {
-      case t: StructType => Some(t)
-      case _ => throw new RuntimeException(
-        s"""Avro schema cannot be converted to a Spark SQL StructType:
-           |
-           |${avroSchema.toString(true)}
-           |""".stripMargin)
-    }
-  }
-
-  private def inferAvroSchemaFromFiles(
-      files: Seq[FileStatus],
-      conf: Configuration,
-      ignoreExtension: Boolean,
-      ignoreCorruptFiles: Boolean): Schema = {
-    // Schema evolution is not supported yet. Here we only pick first random readable sample file to
-    // figure out the schema of the whole dataset.
-    val avroReader = files.iterator.map { f =>
-      val path = f.getPath
-      if (!ignoreExtension && !path.getName.endsWith(".avro")) {
-        None
-      } else {
-        Utils.tryWithResource {
-          new FsInput(path, conf)
-        } { in =>
-          try {
-            Some(DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()))
-          } catch {
-            case e: IOException =>
-              if (ignoreCorruptFiles) {
-                logWarning(s"Skipped the footer in the corrupted file: $path", e)
-                None
-              } else {
-                throw new SparkException(s"Could not read file: $path", e)
-              }
-          }
-        }
-      }
-    }.collectFirst {
-      case Some(reader) => reader
-    }
-
-    avroReader match {
-      case Some(reader) =>
-        try {
-          reader.getSchema
-        } finally {
-          reader.close()
-        }
-      case None =>
-        throw new FileNotFoundException(
-          "No Avro files found. If files don't have .avro extension, set ignoreExtension to true")
-    }
+    AvroUtils.inferSchema(spark, options, files)
   }

   override def shortName(): String = "avro"
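
Note: the removed inference logic (sample the first readable file, honor the
ignoreExtension and ignoreCorruptFiles settings, fail with FileNotFoundException when
nothing is readable) moves behind AvroUtils.inferSchema; judging by the one-line
replacement, the caller-visible contract is unchanged. A minimal read-side sketch of
that contract, with hypothetical paths and record schema:

    import org.apache.spark.sql.SparkSession

    object AvroReadSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("avro-read-sketch").getOrCreate()

        // With no options, the schema is inferred from the first readable sample file.
        val inferred = spark.read.format("avro").load("/data/events")

        // An explicit Avro JSON schema skips file-based inference entirely.
        val schemaJson =
          """{"type":"record","name":"Event","fields":[{"name":"id","type":"long"}]}"""
        val explicit = spark.read.format("avro")
          .option("avroSchema", schemaJson)
          .load("/data/events")

        // The deprecation warning in the removed code points to pathGlobFilter
        // as the replacement for ignoreExtension when filtering by file name.
        val filtered = spark.read.format("avro")
          .option("pathGlobFilter", "*.avro")
          .load("/data/events")

        inferred.printSchema()
        explicit.printSchema()
        filtered.printSchema()
      }
    }
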
@@ -140,32 +71,7 @@ private[avro] class AvroFileFormat extends FileFormat
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf())
-    val outputAvroSchema: Schema = parsedOptions.schema
-      .map(new Schema.Parser().parse)
-      .getOrElse(SchemaConverters.toAvroType(dataSchema, nullable = false,
-        parsedOptions.recordName, parsedOptions.recordNamespace))
-
-    AvroJob.setOutputKeySchema(job, outputAvroSchema)
-
-    if (parsedOptions.compression == "uncompressed") {
-      job.getConfiguration.setBoolean("mapred.output.compress", false)
-    } else {
-      job.getConfiguration.setBoolean("mapred.output.compress", true)
-      logInfo(s"Compressing Avro output using the ${parsedOptions.compression} codec")
-      val codec = parsedOptions.compression match {
-        case DEFLATE_CODEC =>
-          val deflateLevel = spark.sessionState.conf.avroDeflateLevel
-          logInfo(s"Avro compression level $deflateLevel will be used for $DEFLATE_CODEC codec.")
-          job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel)
-          DEFLATE_CODEC
-        case codec @ (SNAPPY_CODEC | BZIP2_CODEC | XZ_CODEC) => codec
-        case unknown => throw new IllegalArgumentException(s"Invalid compression codec: $unknown")
-      }
-      job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, codec)
-    }
-
-    new AvroOutputWriterFactory(dataSchema, outputAvroSchema.toString)
+    AvroUtils.prepareWrite(spark.sessionState.conf, job, options, dataSchema)
  }

   override def buildReader(
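
Note: prepareWrite likewise collapses to a delegation, so the codec table in the removed
code (deflate with a configurable level, snappy, bzip2, xz, or uncompressed) should still
be what the compression option selects. A write-side sketch under that assumption, with a
hypothetical output path:

    import org.apache.spark.sql.SparkSession

    object AvroWriteSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("avro-write-sketch").getOrCreate()
        val df = spark.range(10).toDF("id")

        // Deflate reads its level from spark.sql.avro.deflate.level, the SQL conf
        // behind the avroDeflateLevel accessor in the removed code.
        spark.conf.set("spark.sql.avro.deflate.level", "5")
        df.write
          .format("avro")
          .option("compression", "deflate") // or snappy, bzip2, xz, uncompressed
          .save("/tmp/avro-write-sketch")
      }
    }
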
@@ -250,22 +156,7 @@ private[avro] class AvroFileFormat extends FileFormat
     }
   }

-  override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: AtomicType => true
-
-    case st: StructType => st.forall { f => supportDataType(f.dataType) }
-
-    case ArrayType(elementType, _) => supportDataType(elementType)
-
-    case MapType(keyType, valueType, _) =>
-      supportDataType(keyType) && supportDataType(valueType)
-
-    case udt: UserDefinedType[_] => supportDataType(udt.sqlType)
-
-    case _: NullType => true
-
-    case _ => false
-  }
+  override def supportDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)
 }

 private[avro] object AvroFileFormat {
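
Note: the last hunk swaps the recursive type check for AvroUtils.supportsDataType.
Restated as a standalone sketch, it is the same match the deleted lines carried: atomic
and null types pass directly, containers pass when everything they contain passes, and
UDTs defer to their underlying SQL type. AtomicType is package-qualified in Spark, so
the sketch assumes a home under org.apache.spark.sql, as AvroUtils itself has:

    package org.apache.spark.sql.avro

    import org.apache.spark.sql.types._

    object SupportsDataTypeSketch {
      def supportsDataType(dataType: DataType): Boolean = dataType match {
        case _: AtomicType => true
        case st: StructType => st.forall(f => supportsDataType(f.dataType))
        case ArrayType(elementType, _) => supportsDataType(elementType)
        case MapType(keyType, valueType, _) =>
          supportsDataType(keyType) && supportsDataType(valueType)
        case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
        case _: NullType => true
        case _ => false
      }
    }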