diff --git a/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift b/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift index 903b4078c39b3..8ef08d222b968 100644 --- a/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift +++ b/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift @@ -85,13 +85,15 @@ struct CxxSpan: ParamInfo { ) -> BoundsCheckedThunkBuilder { switch pointerIndex { case .param(let i): - return CxxSpanThunkBuilder(base: base, index: i - 1, signature: funcDecl.signature, + return CxxSpanThunkBuilder( + base: base, index: i - 1, signature: funcDecl.signature, typeMappings: typeMappings, node: original, nonescaping: nonescaping) case .return: if dependencies.isEmpty { return base } - return CxxSpanReturnThunkBuilder(base: base, signature: funcDecl.signature, + return CxxSpanReturnThunkBuilder( + base: base, signature: funcDecl.signature, typeMappings: typeMappings, node: original) case .self: return base @@ -121,7 +123,7 @@ struct CountedBy: ParamInfo { switch pointerIndex { case .param(let i): return CountedOrSizedPointerThunkBuilder( - base: base, index: i-1, countExpr: count, + base: base, index: i - 1, countExpr: count, signature: funcDecl.signature, nonescaping: nonescaping, isSizedBy: sizedBy, skipTrivialCount: skipTrivialCount) case .return: @@ -263,9 +265,9 @@ func isRawPointerType(text: String) -> Bool { // Remove std. or std.__1. prefix func getUnqualifiedStdName(_ type: String) -> String? { - if (type.hasPrefix("std.")) { + if type.hasPrefix("std.") { var ty = type.dropFirst(4) - if (ty.hasPrefix("__1.")) { + if ty.hasPrefix("__1.") { ty = ty.dropFirst(4) } return String(ty) @@ -306,10 +308,14 @@ func hasOwnershipSpecifier(_ attrType: AttributedTypeSyntax) -> Bool { }) } -func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool, _ setMutableSpanInout: Bool) throws -> TypeSyntax { +func transformType( + _ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool, _ setMutableSpanInout: Bool +) throws -> TypeSyntax { if let optType = prev.as(OptionalTypeSyntax.self) { return TypeSyntax( - optType.with(\.wrappedType, try transformType(optType.wrappedType, generateSpan, isSizedBy, setMutableSpanInout))) + optType.with( + \.wrappedType, + try transformType(optType.wrappedType, generateSpan, isSizedBy, setMutableSpanInout))) } if let impOptType = prev.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) { return try transformType(impOptType.wrappedType, generateSpan, isSizedBy, setMutableSpanInout) @@ -318,7 +324,9 @@ func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool, // We insert 'inout' by default for MutableSpan, but it shouldn't override existing ownership let setMutableSpanInoutNext = setMutableSpanInout && !hasOwnershipSpecifier(attrType) return TypeSyntax( - attrType.with(\.baseType, try transformType(attrType.baseType, generateSpan, isSizedBy, setMutableSpanInoutNext))) + attrType.with( + \.baseType, + try transformType(attrType.baseType, generateSpan, isSizedBy, setMutableSpanInoutNext))) } let name = try getTypeName(prev) let text = name.text @@ -336,11 +344,12 @@ func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool, + " - first type token is '\(text)'", node: name) } let token = getSafePointerName(mut: kind, generateSpan: generateSpan, isRaw: isSizedBy) - let mainType = if isSizedBy { - TypeSyntax(IdentifierTypeSyntax(name: token)) - } else { - try replaceTypeName(prev, token) - } + let mainType = + if isSizedBy { + TypeSyntax(IdentifierTypeSyntax(name: token)) + } else { + try replaceTypeName(prev, token) + } if setMutableSpanInout && generateSpan && kind == .Mutable { return TypeSyntax("inout \(mainType)") } @@ -403,7 +412,8 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder { } func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws - -> (FunctionSignatureSyntax, Bool) { + -> (FunctionSignatureSyntax, Bool) + { var newParams = base.signature.parameterClause.parameters.enumerated().filter { let type = argTypes[$0.offset] // filter out deleted parameters, i.e. ones where argTypes[i] _contains_ nil @@ -415,7 +425,8 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder { newParams.append(last.with(\.trailingComma, nil)) } - var sig = base.signature.with(\.parameterClause.parameters, FunctionParameterListSyntax(newParams)) + var sig = base.signature.with( + \.parameterClause.parameters, FunctionParameterListSyntax(newParams)) if returnType != nil { sig = sig.with(\.returnClause!.type, returnType!) } @@ -469,7 +480,8 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder { } func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws - -> (FunctionSignatureSyntax, Bool) { + -> (FunctionSignatureSyntax, Bool) + { var types = argTypes types[index] = try newType return try base.buildFunctionSignature(types, returnType) @@ -518,7 +530,8 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder { } func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws - -> (FunctionSignatureSyntax, Bool) { + -> (FunctionSignatureSyntax, Bool) + { assert(returnType == nil) return try base.buildFunctionSignature(argTypes, newType) } @@ -526,11 +539,12 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder { func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax { let call = try base.buildFunctionCall(pointerArgs) let (_, isConst) = dropCxxQualifiers(try genericArg) - let cast = if isConst { - "Span" - } else { - "MutableSpan" - } + let cast = + if isConst { + "Span" + } else { + "MutableSpan" + } return "unsafe _cxxOverrideLifetime(\(raw: cast)(_unsafeCxxSpan: \(call)), copying: ())" } } @@ -584,11 +598,12 @@ extension SpanBoundsThunkBuilder { var newType: TypeSyntax { get throws { let (strippedArg, isConst) = dropCxxQualifiers(try genericArg) - let mutablePrefix = if isConst { - "" - } else { - "Mutable" - } + let mutablePrefix = + if isConst { + "" + } else { + "Mutable" + } let mainType = replaceBaseType( oldType, TypeSyntax("\(raw: mutablePrefix)Span<\(raw: strippedArg)>")) @@ -610,8 +625,10 @@ protocol PointerBoundsThunkBuilder: BoundsThunkBuilder { extension PointerBoundsThunkBuilder { var nullable: Bool { return oldType.is(OptionalTypeSyntax.self) } - var newType: TypeSyntax { get throws { - return try transformType(oldType, generateSpan, isSizedBy, isParameter) } + var newType: TypeSyntax { + get throws { + return try transformType(oldType, generateSpan, isSizedBy, isParameter) + } } } @@ -650,7 +667,8 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder { } func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws - -> (FunctionSignatureSyntax, Bool) { + -> (FunctionSignatureSyntax, Bool) + { assert(returnType == nil) return try base.buildFunctionSignature(argTypes, newType) } @@ -661,11 +679,12 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder { func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax { let call = try base.buildFunctionCall(pointerArgs) - let startLabel = if generateSpan { - "_unsafeStart" - } else { - "start" - } + let startLabel = + if generateSpan { + "_unsafeStart" + } else { + "start" + } var cast = try newType if nullable { if let optType = cast.as(OptionalTypeSyntax.self) { @@ -689,7 +708,6 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder { } } - struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBoundsThunkBuilder { public let base: BoundsCheckedThunkBuilder public let index: Int @@ -703,7 +721,8 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds var generateSpan: Bool { nonescaping } func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws - -> (FunctionSignatureSyntax, Bool) { + -> (FunctionSignatureSyntax, Bool) + { var types = argTypes types[index] = try newType if skipTrivialCount { @@ -762,12 +781,13 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds let call = try base.buildFunctionCall(args) let ptrRef = unwrapIfNullable(ExprSyntax(DeclReferenceExprSyntax(baseName: name))) - let funcName = switch (isSizedBy, isMutablePointerType(oldType)) { - case (true, true): "withUnsafeMutableBytes" - case (true, false): "withUnsafeBytes" - case (false, true): "withUnsafeMutableBufferPointer" - case (false, false): "withUnsafeBufferPointer" - } + let funcName = + switch (isSizedBy, isMutablePointerType(oldType)) { + case (true, true): "withUnsafeMutableBytes" + case (true, false): "withUnsafeBytes" + case (false, true): "withUnsafeMutableBufferPointer" + case (false, false): "withUnsafeBufferPointer" + } let unwrappedCall = ExprSyntax( """ unsafe \(ptrRef).\(raw: funcName) { \(unwrappedName) in @@ -887,290 +907,286 @@ func getParameterIndexForDeclRef( return try getParameterIndexForParamName((parameterList), ref.baseName) } -/// A macro that adds safe(r) wrappers for functions with unsafe pointer types. -/// Depends on bounds, escapability and lifetime information for each pointer. -/// Intended to map to C attributes like __counted_by, __ended_by and __no_escape, -/// for automatic application by ClangImporter when the C declaration is annotated -/// appropriately. Moreover, it can wrap C++ APIs using unsafe C++ types like -/// std::span with APIs that use their safer Swift equivalents. -public struct SwiftifyImportMacro: PeerMacro { - static func parseEnumName(_ expr: ExprSyntax) throws -> String { - var exprLocal = expr - if let callExpr = expr.as(FunctionCallExprSyntax.self) { - exprLocal = callExpr.calledExpression - } - guard let dotExpr = exprLocal.as(MemberAccessExprSyntax.self) - else { - throw DiagnosticError( - "expected enum literal as argument, got '\(expr)'", - node: expr) - } - return dotExpr.declName.baseName.text +func parseEnumName(_ expr: ExprSyntax) throws -> String { + var exprLocal = expr + if let callExpr = expr.as(FunctionCallExprSyntax.self) { + exprLocal = callExpr.calledExpression } + guard let dotExpr = exprLocal.as(MemberAccessExprSyntax.self) else { + throw DiagnosticError( + "expected enum literal as argument, got '\(expr)'", + node: expr) + } + return dotExpr.declName.baseName.text +} - static func parseEnumArgs(_ expr: ExprSyntax) throws -> LabeledExprListSyntax { - guard let callExpr = expr.as(FunctionCallExprSyntax.self) - else { - throw DiagnosticError( - "expected call to enum constructor, got '\(expr)'", - node: expr) - } - return callExpr.arguments +func parseEnumArgs(_ expr: ExprSyntax) throws -> LabeledExprListSyntax { + guard let callExpr = expr.as(FunctionCallExprSyntax.self) else { + throw DiagnosticError( + "expected call to enum constructor, got '\(expr)'", + node: expr) } + return callExpr.arguments +} - static func getIntLiteralValue(_ expr: ExprSyntax) throws -> Int { - guard let intLiteral = expr.as(IntegerLiteralExprSyntax.self) else { - throw DiagnosticError("expected integer literal, got '\(expr)'", node: expr) - } - guard let res = intLiteral.representedLiteralValue else { - throw DiagnosticError("expected integer literal, got '\(expr)'", node: expr) - } - return res +func getIntLiteralValue(_ expr: ExprSyntax) throws -> Int { + guard let intLiteral = expr.as(IntegerLiteralExprSyntax.self) else { + throw DiagnosticError("expected integer literal, got '\(expr)'", node: expr) + } + guard let res = intLiteral.representedLiteralValue else { + throw DiagnosticError("expected integer literal, got '\(expr)'", node: expr) } + return res +} - static func getBoolLiteralValue(_ expr: ExprSyntax) throws -> Bool { - guard let boolLiteral = expr.as(BooleanLiteralExprSyntax.self) else { - throw DiagnosticError("expected boolean literal, got '\(expr)'", node: expr) - } - switch boolLiteral.literal.tokenKind { - case .keyword(.true): - return true - case .keyword(.false): - return false - default: - throw DiagnosticError("expected bool literal, got '\(expr)'", node: expr) - } +func getBoolLiteralValue(_ expr: ExprSyntax) throws -> Bool { + guard let boolLiteral = expr.as(BooleanLiteralExprSyntax.self) else { + throw DiagnosticError("expected boolean literal, got '\(expr)'", node: expr) } + switch boolLiteral.literal.tokenKind { + case .keyword(.true): + return true + case .keyword(.false): + return false + default: + throw DiagnosticError("expected bool literal, got '\(expr)'", node: expr) + } +} - static func parseSwiftifyExpr(_ expr: ExprSyntax) throws -> SwiftifyExpr { - let enumName = try parseEnumName(expr) - switch enumName { - case "param": - let argumentList = try parseEnumArgs(expr) - if argumentList.count != 1 { - throw DiagnosticError( - "expected single argument to _SwiftifyExpr.param, got \(argumentList.count) arguments", - node: expr) - } - let pointerParamIndexArg = argumentList[argumentList.startIndex] - let pointerParamIndex: Int = try getIntLiteralValue(pointerParamIndexArg.expression) - return .param(pointerParamIndex) - case "return": return .return - case "self": return .`self` - default: +func parseSwiftifyExpr(_ expr: ExprSyntax) throws -> SwiftifyExpr { + let enumName = try parseEnumName(expr) + switch enumName { + case "param": + let argumentList = try parseEnumArgs(expr) + if argumentList.count != 1 { throw DiagnosticError( - "expected 'param', 'return', or 'self', got '\(enumName)'", + "expected single argument to _SwiftifyExpr.param, got \(argumentList.count) arguments", node: expr) } + let pointerParamIndexArg = argumentList[argumentList.startIndex] + let pointerParamIndex: Int = try getIntLiteralValue(pointerParamIndexArg.expression) + return .param(pointerParamIndex) + case "return": return .return + case "self": return .`self` + default: + throw DiagnosticError( + "expected 'param', 'return', or 'self', got '\(enumName)'", + node: expr) } +} - static func parseCountedByEnum( - _ enumConstructorExpr: FunctionCallExprSyntax, _ signature: FunctionSignatureSyntax - ) throws -> ParamInfo { - let argumentList = enumConstructorExpr.arguments - let pointerExprArg = try getArgumentByName(argumentList, "pointer") - let pointerExpr: SwiftifyExpr = try parseSwiftifyExpr(pointerExprArg) - let countExprArg = try getArgumentByName(argumentList, "count") - guard let countExprStringLit = countExprArg.as(StringLiteralExprSyntax.self) else { - throw DiagnosticError( - "expected string literal for 'count' parameter, got \(countExprArg)", node: countExprArg) - } - let unwrappedCountExpr = ExprSyntax(stringLiteral: countExprStringLit.representedLiteralValue!) - if let countVar = unwrappedCountExpr.as(DeclReferenceExprSyntax.self) { - // Perform this lookup here so we can override the position to point to the string literal - // instead of line 1, column 1 - do { - _ = try getParameterIndexForDeclRef(signature.parameterClause.parameters, countVar) - } catch let error as DiagnosticError { - throw DiagnosticError(error.description, node: countExprStringLit, notes: error.notes) - } +func parseCountedByEnum( + _ enumConstructorExpr: FunctionCallExprSyntax, _ signature: FunctionSignatureSyntax +) throws -> ParamInfo { + let argumentList = enumConstructorExpr.arguments + let pointerExprArg = try getArgumentByName(argumentList, "pointer") + let pointerExpr: SwiftifyExpr = try parseSwiftifyExpr(pointerExprArg) + let countExprArg = try getArgumentByName(argumentList, "count") + guard let countExprStringLit = countExprArg.as(StringLiteralExprSyntax.self) else { + throw DiagnosticError( + "expected string literal for 'count' parameter, got \(countExprArg)", node: countExprArg) + } + let unwrappedCountExpr = ExprSyntax(stringLiteral: countExprStringLit.representedLiteralValue!) + if let countVar = unwrappedCountExpr.as(DeclReferenceExprSyntax.self) { + // Perform this lookup here so we can override the position to point to the string literal + // instead of line 1, column 1 + do { + _ = try getParameterIndexForDeclRef(signature.parameterClause.parameters, countVar) + } catch let error as DiagnosticError { + throw DiagnosticError(error.description, node: countExprStringLit, notes: error.notes) } - return CountedBy( - pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: false, - nonescaping: false, dependencies: [], original: ExprSyntax(enumConstructorExpr)) } + return CountedBy( + pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: false, + nonescaping: false, dependencies: [], original: ExprSyntax(enumConstructorExpr)) +} - static func parseSizedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo { - let argumentList = enumConstructorExpr.arguments - let pointerExprArg = try getArgumentByName(argumentList, "pointer") - let pointerExpr: SwiftifyExpr = try parseSwiftifyExpr(pointerExprArg) - let sizeExprArg = try getArgumentByName(argumentList, "size") - guard let sizeExprStringLit = sizeExprArg.as(StringLiteralExprSyntax.self) else { - throw DiagnosticError( - "expected string literal for 'size' parameter, got \(sizeExprArg)", node: sizeExprArg) - } - let unwrappedCountExpr = ExprSyntax(stringLiteral: sizeExprStringLit.representedLiteralValue!) - return CountedBy( - pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: true, nonescaping: false, - dependencies: [], original: ExprSyntax(enumConstructorExpr)) - } - - static func parseEndedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo { - let argumentList = enumConstructorExpr.arguments - let startPointerExprArg = try getArgumentByName(argumentList, "start") - let _: SwiftifyExpr = try parseSwiftifyExpr(startPointerExprArg) - let endPointerExprArg = try getArgumentByName(argumentList, "end") - let _: SwiftifyExpr = try parseSwiftifyExpr(endPointerExprArg) - throw RuntimeError("endedBy support not yet implemented") - } - - static func parseNonEscaping(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> Int { - let argumentList = enumConstructorExpr.arguments - let pointerExprArg = try getArgumentByName(argumentList, "pointer") - let pointerExpr: SwiftifyExpr = try parseSwiftifyExpr(pointerExprArg) - let pointerParamIndex: Int = paramOrReturnIndex(pointerExpr) - return pointerParamIndex - } - - static func parseLifetimeDependence(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> (SwiftifyExpr, LifetimeDependence) { - let argumentList = enumConstructorExpr.arguments - let pointer: SwiftifyExpr = try parseSwiftifyExpr(try getArgumentByName(argumentList, "pointer")) - let dependsOnArg = try getArgumentByName(argumentList, "dependsOn") - let dependsOn: SwiftifyExpr = try parseSwiftifyExpr(dependsOnArg) - if dependsOn == .`return` { - throw DiagnosticError("lifetime cannot depend on the return value", node: dependsOnArg) - } - let type = try getArgumentByName(argumentList, "type") - let depType: DependenceType - switch try parseEnumName(type) { - case "borrow": - depType = DependenceType.borrow - case "copy": - depType = DependenceType.copy - default: - throw DiagnosticError("expected '.copy' or '.borrow', got '\(type)'", node: type) - } - let dependence = LifetimeDependence(dependsOn: dependsOn, type: depType) - return (pointer, dependence) +func parseSizedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo { + let argumentList = enumConstructorExpr.arguments + let pointerExprArg = try getArgumentByName(argumentList, "pointer") + let pointerExpr: SwiftifyExpr = try parseSwiftifyExpr(pointerExprArg) + let sizeExprArg = try getArgumentByName(argumentList, "size") + guard let sizeExprStringLit = sizeExprArg.as(StringLiteralExprSyntax.self) else { + throw DiagnosticError( + "expected string literal for 'size' parameter, got \(sizeExprArg)", node: sizeExprArg) } + let unwrappedCountExpr = ExprSyntax(stringLiteral: sizeExprStringLit.representedLiteralValue!) + return CountedBy( + pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: true, nonescaping: false, + dependencies: [], original: ExprSyntax(enumConstructorExpr)) +} - static func parseTypeMappingParam(_ paramAST: LabeledExprSyntax?) throws -> [String: String]? { - guard let unwrappedParamAST = paramAST else { - return nil - } - let paramExpr = unwrappedParamAST.expression - guard let dictExpr = paramExpr.as(DictionaryExprSyntax.self) else { - return nil - } - var dict : [String: String] = [:] - switch dictExpr.content { - case .colon(_): - return dict - case .elements(let types): - for element in types { - guard let key = element.key.as(StringLiteralExprSyntax.self) else { - throw DiagnosticError("expected a string literal, got '\(element.key)'", node: element.key) - } - guard let value = element.value.as(StringLiteralExprSyntax.self) else { - throw DiagnosticError("expected a string literal, got '\(element.value)'", node: element.value) - } - dict[key.representedLiteralValue!] = value.representedLiteralValue! - } - @unknown default: - throw DiagnosticError("unknown dictionary literal", node: dictExpr) - } - return dict +func parseEndedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo { + let argumentList = enumConstructorExpr.arguments + let startPointerExprArg = try getArgumentByName(argumentList, "start") + let _: SwiftifyExpr = try parseSwiftifyExpr(startPointerExprArg) + let endPointerExprArg = try getArgumentByName(argumentList, "end") + let _: SwiftifyExpr = try parseSwiftifyExpr(endPointerExprArg) + throw RuntimeError("endedBy support not yet implemented") +} + +func parseNonEscaping(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> Int { + let argumentList = enumConstructorExpr.arguments + let pointerExprArg = try getArgumentByName(argumentList, "pointer") + let pointerExpr: SwiftifyExpr = try parseSwiftifyExpr(pointerExprArg) + let pointerParamIndex: Int = paramOrReturnIndex(pointerExpr) + return pointerParamIndex +} + +func parseLifetimeDependence(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ( + SwiftifyExpr, LifetimeDependence +) { + let argumentList = enumConstructorExpr.arguments + let pointer: SwiftifyExpr = try parseSwiftifyExpr(try getArgumentByName(argumentList, "pointer")) + let dependsOnArg = try getArgumentByName(argumentList, "dependsOn") + let dependsOn: SwiftifyExpr = try parseSwiftifyExpr(dependsOnArg) + if dependsOn == .`return` { + throw DiagnosticError("lifetime cannot depend on the return value", node: dependsOnArg) + } + let type = try getArgumentByName(argumentList, "type") + let depType: DependenceType + switch try parseEnumName(type) { + case "borrow": + depType = DependenceType.borrow + case "copy": + depType = DependenceType.copy + default: + throw DiagnosticError("expected '.copy' or '.borrow', got '\(type)'", node: type) } + let dependence = LifetimeDependence(dependsOn: dependsOn, type: depType) + return (pointer, dependence) +} - static func parseCxxSpansInSignature( - _ signature: FunctionSignatureSyntax, - _ typeMappings: [String: String]? - ) throws -> [ParamInfo] { - guard let typeMappings else { - return [] +func parseTypeMappingParam(_ paramAST: LabeledExprSyntax?) throws -> [String: String]? { + guard let unwrappedParamAST = paramAST else { + return nil + } + let paramExpr = unwrappedParamAST.expression + guard let dictExpr = paramExpr.as(DictionaryExprSyntax.self) else { + return nil + } + var dict: [String: String] = [:] + switch dictExpr.content { + case .colon(_): + return dict + case .elements(let types): + for element in types { + guard let key = element.key.as(StringLiteralExprSyntax.self) else { + throw DiagnosticError("expected a string literal, got '\(element.key)'", node: element.key) + } + guard let value = element.value.as(StringLiteralExprSyntax.self) else { + throw DiagnosticError( + "expected a string literal, got '\(element.value)'", node: element.value) + } + dict[key.representedLiteralValue!] = value.representedLiteralValue! } - var result : [ParamInfo] = [] - let process : (TypeSyntax, SwiftifyExpr, SyntaxProtocol) throws -> () = { type, expr, orig in - let typeName = getUnattributedType(type).description - if let desugaredType = typeMappings[typeName] { - if let unqualifiedDesugaredType = getUnqualifiedStdName(desugaredType) { - if unqualifiedDesugaredType.starts(with: "span<") { - result.append(CxxSpan(pointerIndex: expr, nonescaping: false, + @unknown default: + throw DiagnosticError("unknown dictionary literal", node: dictExpr) + } + return dict +} + +func parseCxxSpansInSignature( + _ signature: FunctionSignatureSyntax, + _ typeMappings: [String: String]? +) throws -> [ParamInfo] { + guard let typeMappings else { + return [] + } + var result: [ParamInfo] = [] + let process: (TypeSyntax, SwiftifyExpr, SyntaxProtocol) throws -> Void = { type, expr, orig in + let typeName = getUnattributedType(type).description + if let desugaredType = typeMappings[typeName] { + if let unqualifiedDesugaredType = getUnqualifiedStdName(desugaredType) { + if unqualifiedDesugaredType.starts(with: "span<") { + result.append( + CxxSpan( + pointerIndex: expr, nonescaping: false, dependencies: [], typeMappings: typeMappings, original: orig)) - } } } } - for (idx, param) in signature.parameterClause.parameters.enumerated() { - try process(param.type, .param(idx + 1), param) - } - if let retClause = signature.returnClause { - try process(retClause.type, .`return`, retClause) - } - return result } + for (idx, param) in signature.parameterClause.parameters.enumerated() { + try process(param.type, .param(idx + 1), param) + } + if let retClause = signature.returnClause { + try process(retClause.type, .`return`, retClause) + } + return result +} - static func parseMacroParam( - _ paramAST: LabeledExprSyntax, _ signature: FunctionSignatureSyntax, - nonescapingPointers: inout Set, - lifetimeDependencies: inout [SwiftifyExpr: [LifetimeDependence]] - ) throws -> ParamInfo? { - let paramExpr = paramAST.expression - guard let enumConstructorExpr = paramExpr.as(FunctionCallExprSyntax.self) else { - throw DiagnosticError( - "expected _SwiftifyInfo enum literal as argument, got '\(paramExpr)'", node: paramExpr) - } - let enumName = try parseEnumName(paramExpr) - switch enumName { - case "countedBy": return try parseCountedByEnum(enumConstructorExpr, signature) - case "sizedBy": return try parseSizedByEnum(enumConstructorExpr) - case "endedBy": return try parseEndedByEnum(enumConstructorExpr) - case "nonescaping": - let index = try parseNonEscaping(enumConstructorExpr) - nonescapingPointers.insert(index) - return nil - case "lifetimeDependence": - let (expr, dependence) = try parseLifetimeDependence(enumConstructorExpr) - lifetimeDependencies[expr, default: []].append(dependence) - // We assume pointers annotated with lifetimebound do not escape. - let fromIdx = paramOrReturnIndex(dependence.dependsOn) - if dependence.type == DependenceType.copy && fromIdx != 0 { - nonescapingPointers.insert(fromIdx) - } - // The escaping is controlled when a parameter is the target of a lifetimebound. - // So we want to do the transformation to Swift's Span. - let idx = paramOrReturnIndex(expr) - if idx != -1 { - nonescapingPointers.insert(idx) - } - return nil - default: - throw DiagnosticError( - "expected 'countedBy', 'sizedBy', 'endedBy', 'nonescaping' or 'lifetimeDependence', got '\(enumName)'", - node: enumConstructorExpr) +func parseMacroParam( + _ paramAST: LabeledExprSyntax, _ signature: FunctionSignatureSyntax, + nonescapingPointers: inout Set, + lifetimeDependencies: inout [SwiftifyExpr: [LifetimeDependence]] +) throws -> ParamInfo? { + let paramExpr = paramAST.expression + guard let enumConstructorExpr = paramExpr.as(FunctionCallExprSyntax.self) else { + throw DiagnosticError( + "expected _SwiftifyInfo enum literal as argument, got '\(paramExpr)'", node: paramExpr) + } + let enumName = try parseEnumName(paramExpr) + switch enumName { + case "countedBy": return try parseCountedByEnum(enumConstructorExpr, signature) + case "sizedBy": return try parseSizedByEnum(enumConstructorExpr) + case "endedBy": return try parseEndedByEnum(enumConstructorExpr) + case "nonescaping": + let index = try parseNonEscaping(enumConstructorExpr) + nonescapingPointers.insert(index) + return nil + case "lifetimeDependence": + let (expr, dependence) = try parseLifetimeDependence(enumConstructorExpr) + lifetimeDependencies[expr, default: []].append(dependence) + // We assume pointers annotated with lifetimebound do not escape. + let fromIdx = paramOrReturnIndex(dependence.dependsOn) + if dependence.type == DependenceType.copy && fromIdx != 0 { + nonescapingPointers.insert(fromIdx) + } + // The escaping is controlled when a parameter is the target of a lifetimebound. + // So we want to do the transformation to Swift's Span. + let idx = paramOrReturnIndex(expr) + if idx != -1 { + nonescapingPointers.insert(idx) } + return nil + default: + throw DiagnosticError( + "expected 'countedBy', 'sizedBy', 'endedBy', 'nonescaping' or 'lifetimeDependence', got '\(enumName)'", + node: enumConstructorExpr) } +} - static func hasTrivialCountVariants(_ parsedArgs: [ParamInfo]) -> Bool { - let countExprs = parsedArgs.compactMap { - switch $0 { - case let c as CountedBy: return c.count - default: return nil - } - } - let trivialCounts = countExprs.filter { - $0.is(DeclReferenceExprSyntax.self) || $0.is(IntegerLiteralExprSyntax.self) - } - // don't generate trivial count variants if there are any non-trivial counts - if trivialCounts.count < countExprs.count { - return false +func hasTrivialCountVariants(_ parsedArgs: [ParamInfo]) -> Bool { + let countExprs = parsedArgs.compactMap { + switch $0 { + case let c as CountedBy: return c.count + default: return nil } - let countVars = trivialCounts.filter { $0.is(DeclReferenceExprSyntax.self) } - let distinctCountVars = Set( - countVars.map { - return $0.as(DeclReferenceExprSyntax.self)!.baseName.text - }) - // don't generate trivial count variants if two count expressions refer to the same parameter - return countVars.count == distinctCountVars.count - } - - static func checkArgs(_ args: [ParamInfo], _ funcDecl: FunctionDeclSyntax) throws { - var argByIndex: [Int: ParamInfo] = [:] - var ret: ParamInfo? = nil - let paramCount = funcDecl.signature.parameterClause.parameters.count - try args.forEach { pointerInfo in - switch pointerInfo.pointerIndex { - case .param(let i): + } + let trivialCounts = countExprs.filter { + $0.is(DeclReferenceExprSyntax.self) || $0.is(IntegerLiteralExprSyntax.self) + } + // don't generate trivial count variants if there are any non-trivial counts + if trivialCounts.count < countExprs.count { + return false + } + let countVars = trivialCounts.filter { $0.is(DeclReferenceExprSyntax.self) } + let distinctCountVars = Set( + countVars.map { + return $0.as(DeclReferenceExprSyntax.self)!.baseName.text + }) + // don't generate trivial count variants if two count expressions refer to the same parameter + return countVars.count == distinctCountVars.count +} + +func checkArgs(_ args: [ParamInfo], _ funcDecl: FunctionDeclSyntax) throws { + var argByIndex: [Int: ParamInfo] = [:] + var ret: ParamInfo? = nil + let paramCount = funcDecl.signature.parameterClause.parameters.count + try args.forEach { pointerInfo in + switch pointerInfo.pointerIndex { + case .param(let i): if i < 1 || i > paramCount { let noteMessage = paramCount > 0 @@ -1188,133 +1204,160 @@ public struct SwiftifyImportMacro: PeerMacro { + "\(i): \(pointerInfo) and \(argByIndex[i]!)", node: pointerInfo.original) } argByIndex[i] = pointerInfo - case .return: + case .return: if ret != nil { throw DiagnosticError( - "multiple _SwiftifyInfos referring to return value: \(pointerInfo) and \(ret!)", node: pointerInfo.original) + "multiple _SwiftifyInfos referring to return value: \(pointerInfo) and \(ret!)", + node: pointerInfo.original) } ret = pointerInfo - case .self: - throw DiagnosticError("do not annotate self", node: pointerInfo.original) - } + case .self: + throw DiagnosticError("do not annotate self", node: pointerInfo.original) } } +} - static func paramOrReturnIndex(_ expr: SwiftifyExpr) -> Int { - switch expr { - case .param(let i): return i - case .`self`: return 0 - case .return: return -1 - } +func paramOrReturnIndex(_ expr: SwiftifyExpr) -> Int { + switch expr { + case .param(let i): return i + case .`self`: return 0 + case .return: return -1 } +} - static func setNonescapingPointers(_ args: inout [ParamInfo], _ nonescapingPointers: Set) { - if args.isEmpty { - return - } - for i in 0...args.count - 1 where nonescapingPointers.contains(paramOrReturnIndex(args[i].pointerIndex)) { - args[i].nonescaping = true - } +func setNonescapingPointers(_ args: inout [ParamInfo], _ nonescapingPointers: Set) { + if args.isEmpty { + return } + for i in 0...args.count - 1 + where nonescapingPointers.contains(paramOrReturnIndex(args[i].pointerIndex)) { + args[i].nonescaping = true + } +} - static func setLifetimeDependencies(_ args: inout [ParamInfo], _ lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]]) { - if args.isEmpty { - return - } - for i in 0...args.count - 1 where lifetimeDependencies.keys.contains(args[i].pointerIndex) { - args[i].dependencies = lifetimeDependencies[args[i].pointerIndex]! - } +func setLifetimeDependencies( + _ args: inout [ParamInfo], _ lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]] +) { + if args.isEmpty { + return } + for i in 0...args.count - 1 where lifetimeDependencies.keys.contains(args[i].pointerIndex) { + args[i].dependencies = lifetimeDependencies[args[i].pointerIndex]! + } +} - static func getReturnLifetimeAttribute(_ funcDecl: FunctionDeclSyntax, - _ dependencies: [SwiftifyExpr: [LifetimeDependence]]) -> [AttributeListSyntax.Element] { - let returnDependencies = dependencies[.`return`, default: []] - if returnDependencies.isEmpty { - return [] - } - var args : [LabeledExprSyntax] = [] - for dependence in returnDependencies { - switch dependence.type { - case .borrow: - args.append(LabeledExprSyntax(expression: - DeclReferenceExprSyntax(baseName: TokenSyntax("borrow")))) - case .copy: - args.append(LabeledExprSyntax(expression: - DeclReferenceExprSyntax(baseName: TokenSyntax("copy")))) - } - args.append(LabeledExprSyntax(expression: - DeclReferenceExprSyntax(baseName: TokenSyntax(tryGetParamName(funcDecl, dependence.dependsOn))!), +func getReturnLifetimeAttribute( + _ funcDecl: FunctionDeclSyntax, + _ dependencies: [SwiftifyExpr: [LifetimeDependence]] +) -> [AttributeListSyntax.Element] { + let returnDependencies = dependencies[.`return`, default: []] + if returnDependencies.isEmpty { + return [] + } + var args: [LabeledExprSyntax] = [] + for dependence in returnDependencies { + switch dependence.type { + case .borrow: + args.append( + LabeledExprSyntax( + expression: + DeclReferenceExprSyntax(baseName: TokenSyntax("borrow")))) + case .copy: + args.append( + LabeledExprSyntax( + expression: + DeclReferenceExprSyntax(baseName: TokenSyntax("copy")))) + } + args.append( + LabeledExprSyntax( + expression: + DeclReferenceExprSyntax( + baseName: TokenSyntax(tryGetParamName(funcDecl, dependence.dependsOn))!), trailingComma: .commaToken())) - } - args[args.count - 1] = args[args.count - 1].with(\.trailingComma, nil) - return [.attribute(AttributeSyntax( - atSign: .atSignToken(), - attributeName: IdentifierTypeSyntax(name: "lifetime"), - leftParen: .leftParenToken(), - arguments: .argumentList(LabeledExprListSyntax(args)), - rightParen: .rightParenToken()))] } + args[args.count - 1] = args[args.count - 1].with(\.trailingComma, nil) + return [ + .attribute( + AttributeSyntax( + atSign: .atSignToken(), + attributeName: IdentifierTypeSyntax(name: "lifetime"), + leftParen: .leftParenToken(), + arguments: .argumentList(LabeledExprListSyntax(args)), + rightParen: .rightParenToken())) + ] +} - static func isMutableSpan(_ type: TypeSyntax) -> Bool { - if let optType = type.as(OptionalTypeSyntax.self) { - return isMutableSpan(optType.wrappedType) - } - if let impOptType = type.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) { - return isMutableSpan(impOptType.wrappedType) - } - if let attrType = type.as(AttributedTypeSyntax.self) { - return isMutableSpan(attrType.baseType) - } - guard let identifierType = type.as(IdentifierTypeSyntax.self) else { - return false - } - let name = identifierType.name.text - return name == "MutableSpan" || name == "MutableRawSpan" +func isMutableSpan(_ type: TypeSyntax) -> Bool { + if let optType = type.as(OptionalTypeSyntax.self) { + return isMutableSpan(optType.wrappedType) } - - static func containsLifetimeAttr(_ attrs: AttributeListSyntax, for paramName: TokenSyntax) -> Bool { - for elem in attrs { - guard let attr = elem.as(AttributeSyntax.self) else { - continue - } - if attr.attributeName != "lifetime" { - continue - } - guard let args = attr.arguments?.as(LabeledExprListSyntax.self) else { - continue - } - for arg in args { - if arg.label == paramName { - return true - } - } - } + if let impOptType = type.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) { + return isMutableSpan(impOptType.wrappedType) + } + if let attrType = type.as(AttributedTypeSyntax.self) { + return isMutableSpan(attrType.baseType) + } + guard let identifierType = type.as(IdentifierTypeSyntax.self) else { return false } + let name = identifierType.name.text + return name == "MutableSpan" || name == "MutableRawSpan" +} - // Mutable[Raw]Span parameters need explicit @lifetime annotations since they are inout - static func paramLifetimeAttributes(_ newSignature: FunctionSignatureSyntax, _ oldAttrs: AttributeListSyntax) -> [AttributeListSyntax.Element] { - var defaultLifetimes: [AttributeListSyntax.Element] = [] - for param in newSignature.parameterClause.parameters { - if !isMutableSpan(param.type) { - continue - } - let paramName = param.secondName ?? param.firstName - if containsLifetimeAttr(oldAttrs, for: paramName) { - continue +func containsLifetimeAttr(_ attrs: AttributeListSyntax, for paramName: TokenSyntax) -> Bool { + for elem in attrs { + guard let attr = elem.as(AttributeSyntax.self) else { + continue + } + if attr.attributeName != "lifetime" { + continue + } + guard let args = attr.arguments?.as(LabeledExprListSyntax.self) else { + continue + } + for arg in args { + if arg.label == paramName { + return true } - let expr = ExprSyntax("\(paramName): copy \(paramName)") - - defaultLifetimes.append(.attribute(AttributeSyntax( - atSign: .atSignToken(), - attributeName: IdentifierTypeSyntax(name: "lifetime"), - leftParen: .leftParenToken(), - arguments: .argumentList(LabeledExprListSyntax([LabeledExprSyntax(expression: expr)])), - rightParen: .rightParenToken()))) } - return defaultLifetimes } + return false +} +// Mutable[Raw]Span parameters need explicit @lifetime annotations since they are inout +func paramLifetimeAttributes( + _ newSignature: FunctionSignatureSyntax, _ oldAttrs: AttributeListSyntax +) -> [AttributeListSyntax.Element] { + var defaultLifetimes: [AttributeListSyntax.Element] = [] + for param in newSignature.parameterClause.parameters { + if !isMutableSpan(param.type) { + continue + } + let paramName = param.secondName ?? param.firstName + if containsLifetimeAttr(oldAttrs, for: paramName) { + continue + } + let expr = ExprSyntax("\(paramName): copy \(paramName)") + + defaultLifetimes.append( + .attribute( + AttributeSyntax( + atSign: .atSignToken(), + attributeName: IdentifierTypeSyntax(name: "lifetime"), + leftParen: .leftParenToken(), + arguments: .argumentList(LabeledExprListSyntax([LabeledExprSyntax(expression: expr)])), + rightParen: .rightParenToken()))) + } + return defaultLifetimes +} + +/// A macro that adds safe(r) wrappers for functions with unsafe pointer types. +/// Depends on bounds, escapability and lifetime information for each pointer. +/// Intended to map to C attributes like __counted_by, __ended_by and __no_escape, +/// for automatic application by ClangImporter when the C declaration is annotated +/// appropriately. Moreover, it can wrap C++ APIs using unsafe C++ types like +/// std::span with APIs that use their safer Swift equivalents. +public struct SwiftifyImportMacro: PeerMacro { public static func expansion( of node: AttributeSyntax, providingPeersOf declaration: some DeclSyntaxProtocol, @@ -1326,15 +1369,16 @@ public struct SwiftifyImportMacro: PeerMacro { } let argumentList = node.arguments!.as(LabeledExprListSyntax.self)! - var arguments = Array(argumentList) + var arguments = [LabeledExprSyntax](argumentList) let typeMappings = try parseTypeMappingParam(arguments.last) if typeMappings != nil { arguments = arguments.dropLast() } var nonescapingPointers = Set() - var lifetimeDependencies : [SwiftifyExpr: [LifetimeDependence]] = [:] + var lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]] = [:] var parsedArgs = try arguments.compactMap { - try parseMacroParam($0, funcDecl.signature, nonescapingPointers: &nonescapingPointers, + try parseMacroParam( + $0, funcDecl.signature, nonescapingPointers: &nonescapingPointers, lifetimeDependencies: &lifetimeDependencies) } parsedArgs.append(contentsOf: try parseCxxSpansInSignature(funcDecl.signature, typeMappings)) @@ -1383,13 +1427,16 @@ public struct SwiftifyImportMacro: PeerMacro { expression: try builder.buildFunctionCall([:])))) let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call])) let returnLifetimeAttribute = getReturnLifetimeAttribute(funcDecl, lifetimeDependencies) - let lifetimeAttrs = returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcDecl.attributes) - let disfavoredOverload : [AttributeListSyntax.Element] = (onlyReturnTypeChanged ? [ - .attribute( - AttributeSyntax( - atSign: .atSignToken(), - attributeName: IdentifierTypeSyntax(name: "_disfavoredOverload"))) - ] : []) + let lifetimeAttrs = + returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcDecl.attributes) + let disfavoredOverload: [AttributeListSyntax.Element] = + (onlyReturnTypeChanged + ? [ + .attribute( + AttributeSyntax( + atSign: .atSignToken(), + attributeName: IdentifierTypeSyntax(name: "_disfavoredOverload"))) + ] : []) let newFunc = funcDecl .with(\.signature, newSignature) @@ -1410,8 +1457,8 @@ public struct SwiftifyImportMacro: PeerMacro { atSign: .atSignToken(), attributeName: IdentifierTypeSyntax(name: "_alwaysEmitIntoClient"))) ] - + lifetimeAttrs - + disfavoredOverload) + + lifetimeAttrs + + disfavoredOverload) return [DeclSyntax(newFunc)] } catch let error as DiagnosticError { context.diagnose(