Skip to content

Commit df1674f

Browse files
alexandre-dauboisdunglas
authored andcommitted
feat(gofile): use templates to generate the Go file (#1666)
1 parent e18ceb9 commit df1674f

10 files changed

Lines changed: 153 additions & 116 deletions

File tree

internal/extgen/arginfo.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func (ag *arginfoGenerator) fixArginfoFile(stubFile string) error {
4141
return fmt.Errorf("reading arginfo file: %w", err)
4242
}
4343

44-
// TODO: Fix the zend_register_internal_class_with_flags issue
44+
// FIXME: the script generate "zend_register_internal_class_with_flags" but it is not recognized by the compiler
4545
fixedContent := strings.ReplaceAll(content,
4646
"zend_register_internal_class_with_flags(&ce, NULL, 0)",
4747
"zend_register_internal_class(&ce)")

internal/extgen/classparser.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,13 @@ func (cp *classParser) parseStructField(fieldName string, field *ast.Field) phpC
155155
// check if field is a pointer (nullable)
156156
if starExpr, isPointer := field.Type.(*ast.StarExpr); isPointer {
157157
prop.IsNullable = true
158-
prop.goType = cp.typeToString(starExpr.X)
158+
prop.GoType = cp.typeToString(starExpr.X)
159159
} else {
160160
prop.IsNullable = false
161-
prop.goType = cp.typeToString(field.Type)
161+
prop.GoType = cp.typeToString(field.Type)
162162
}
163163

164-
prop.PhpType = cp.goTypeToPHPType(prop.goType)
164+
prop.PhpType = cp.goTypeToPHPType(prop.GoType)
165165

166166
return prop
167167
}
@@ -260,13 +260,13 @@ func (cp *classParser) parseMethods(filename string) (methods []phpClassMethod,
260260
return nil, fmt.Errorf("extracting Go method function: %w", err)
261261
}
262262

263-
currentMethod.goFunction = goFunc
263+
currentMethod.GoFunction = goFunc
264264

265265
validator := Validator{}
266266
phpFunc := phpFunction{
267267
Name: currentMethod.Name,
268268
Signature: currentMethod.Signature,
269-
goFunction: currentMethod.goFunction,
269+
GoFunction: currentMethod.GoFunction,
270270
Params: currentMethod.Params,
271271
ReturnType: currentMethod.ReturnType,
272272
IsReturnNullable: currentMethod.isReturnNullable,

internal/extgen/funcparser.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func (fp *FuncParser) parse(filename string) (functions []phpFunction, err error
7474
return nil, fmt.Errorf("extracting Go function: %w", err)
7575
}
7676

77-
currentPHPFunc.goFunction = goFunc
77+
currentPHPFunc.GoFunction = goFunc
7878

7979
if err := validator.validateGoFunctionSignatureWithOptions(*currentPHPFunc, false); err != nil {
8080
fmt.Printf("Warning: Go function signature mismatch for %q: %v\n", currentPHPFunc.Name, err)

internal/extgen/gofile.go

Lines changed: 44 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,33 @@
11
package extgen
22

33
import (
4+
"bytes"
5+
_ "embed"
46
"fmt"
57
"path/filepath"
68
"strings"
9+
"text/template"
10+
11+
"github.com/Masterminds/sprig/v3"
712
)
813

14+
//go:embed templates/extension.go.tpl
15+
var goFileContent string
16+
917
type GoFileGenerator struct {
1018
generator *Generator
1119
}
1220

21+
type goTemplateData struct {
22+
PackageName string
23+
BaseName string
24+
Imports []string
25+
Constants []phpConstant
26+
InternalFunctions []string
27+
Functions []phpFunction
28+
Classes []phpClass
29+
}
30+
1331
func (gg *GoFileGenerator) generate() error {
1432
filename := filepath.Join(gg.generator.BuildDir, gg.generator.BaseName+".go")
1533
content, err := gg.buildContent()
@@ -27,104 +45,47 @@ func (gg *GoFileGenerator) buildContent() (string, error) {
2745
return "", fmt.Errorf("analyzing source file: %w", err)
2846
}
2947

30-
var builder strings.Builder
31-
32-
cleanPackageName := SanitizePackageName(gg.generator.BaseName)
33-
builder.WriteString(fmt.Sprintf(`package %s
34-
35-
/*
36-
#include <stdlib.h>
37-
#include "%s.h"
38-
*/
39-
import "C"
40-
import "runtime/cgo"
41-
`, cleanPackageName, gg.generator.BaseName))
42-
48+
filteredImports := make([]string, 0, len(imports))
4349
for _, imp := range imports {
44-
if imp == `"C"` {
45-
continue
50+
if imp != `"C"` {
51+
filteredImports = append(filteredImports, imp)
4652
}
47-
48-
builder.WriteString(fmt.Sprintf("import %s\n", imp))
4953
}
5054

51-
builder.WriteString(`
52-
func init() {
53-
frankenphp.RegisterExtension(unsafe.Pointer(&C.ext_module_entry))
54-
}
55-
`)
56-
57-
for _, constant := range gg.generator.Constants {
58-
builder.WriteString(fmt.Sprintf("const %s = %s\n", constant.Name, constant.Value))
59-
}
60-
61-
if len(gg.generator.Constants) > 0 {
62-
builder.WriteString("\n")
63-
}
64-
65-
for _, internalFunc := range internalFunctions {
66-
builder.WriteString(internalFunc + "\n\n")
67-
}
68-
69-
for _, fn := range gg.generator.Functions {
70-
builder.WriteString(fmt.Sprintf("//export %s\n%s\n", fn.Name, fn.goFunction))
71-
}
72-
73-
for _, class := range gg.generator.Classes {
74-
builder.WriteString(fmt.Sprintf("type %s struct {\n", class.GoStruct))
75-
for _, prop := range class.Properties {
76-
builder.WriteString(fmt.Sprintf(" %s %s\n", prop.Name, prop.goType))
55+
classes := make([]phpClass, len(gg.generator.Classes))
56+
copy(classes, gg.generator.Classes)
57+
for i, class := range classes {
58+
for j, method := range class.Methods {
59+
classes[i].Methods[j].Wrapper = gg.generateMethodWrapper(method, class)
7760
}
78-
builder.WriteString("}\n\n")
7961
}
8062

81-
if len(gg.generator.Classes) > 0 {
82-
builder.WriteString(`
83-
//export registerGoObject
84-
func registerGoObject(obj interface{}) C.uintptr_t {
85-
handle := cgo.NewHandle(obj)
86-
return C.uintptr_t(handle)
87-
}
88-
89-
//export getGoObject
90-
func getGoObject(handle C.uintptr_t) interface{} {
91-
h := cgo.Handle(handle)
92-
return h.value()
93-
}
94-
95-
//export removeGoObject
96-
func removeGoObject(handle C.uintptr_t) {
97-
h := cgo.Handle(handle)
98-
h.Delete()
99-
}
63+
templateContent, err := gg.getTemplateContent(goTemplateData{
64+
PackageName: SanitizePackageName(gg.generator.BaseName),
65+
BaseName: gg.generator.BaseName,
66+
Imports: filteredImports,
67+
Constants: gg.generator.Constants,
68+
InternalFunctions: internalFunctions,
69+
Functions: gg.generator.Functions,
70+
Classes: classes,
71+
})
10072

101-
`)
73+
if err != nil {
74+
return "", fmt.Errorf("executing template: %w", err)
10275
}
10376

104-
for _, class := range gg.generator.Classes {
105-
builder.WriteString(fmt.Sprintf(`//export create_%s_object
106-
func create_%s_object() C.uintptr_t {
107-
obj := &%s{}
108-
return registerGoObject(obj)
77+
return templateContent, nil
10978
}
11079

111-
`, class.GoStruct, class.GoStruct, class.GoStruct))
80+
func (gg *GoFileGenerator) getTemplateContent(data goTemplateData) (string, error) {
81+
tmpl := template.Must(template.New("gofile").Funcs(sprig.FuncMap()).Parse(goFileContent))
11282

113-
for _, method := range class.Methods {
114-
if method.goFunction != "" {
115-
builder.WriteString(method.goFunction)
116-
builder.WriteString("\n\n")
117-
}
118-
}
119-
120-
for _, method := range class.Methods {
121-
builder.WriteString(fmt.Sprintf("//export %s_wrapper\n", method.Name))
122-
builder.WriteString(gg.generateMethodWrapper(method, class))
123-
builder.WriteString("\n")
124-
}
83+
var buf bytes.Buffer
84+
if err := tmpl.Execute(&buf, data); err != nil {
85+
return "", err
12586
}
12687

127-
return builder.String(), nil
88+
return buf.String(), nil
12889
}
12990

13091
func (gg *GoFileGenerator) generateMethodWrapper(method phpClassMethod, class phpClass) string {

internal/extgen/gofile_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ func anotherHelper() {
5151
{
5252
Name: "greet",
5353
ReturnType: "string",
54-
goFunction: `func greet(name *go_string) *go_value {
54+
GoFunction: `func greet(name *go_string) *go_value {
5555
return types.String("Hello " + CStringToGoString(name))
5656
}`,
5757
},
5858
{
5959
Name: "calculate",
6060
ReturnType: "int",
61-
goFunction: `func calculate(a long, b long) *go_value {
61+
GoFunction: `func calculate(a long, b long) *go_value {
6262
result := a + b
6363
return types.Int(result)
6464
}`,
@@ -103,7 +103,7 @@ func test() {
103103
{
104104
Name: "test",
105105
ReturnType: "void",
106-
goFunction: "func test() {\n\t// simple function\n}",
106+
GoFunction: "func test() {\n\t// simple function\n}",
107107
},
108108
},
109109
contains: []string{
@@ -136,7 +136,7 @@ func process(data *go_string) *go_value {
136136
{
137137
Name: "process",
138138
ReturnType: "string",
139-
goFunction: `func process(data *go_string) *go_value {
139+
GoFunction: `func process(data *go_string) *go_value {
140140
return String(fmt.Sprintf("processed: %s", CStringToGoString(data)))
141141
}`,
142142
},
@@ -169,7 +169,7 @@ func internalFunc2(data string) {
169169
{
170170
Name: "publicFunc",
171171
ReturnType: "void",
172-
goFunction: "func publicFunc() {}",
172+
GoFunction: "func publicFunc() {}",
173173
},
174174
},
175175
contains: []string{
@@ -219,7 +219,7 @@ func TestGoFileGenerator_PackageNameSanitization(t *testing.T) {
219219
BaseName: tt.baseName,
220220
SourceFile: sourceFile,
221221
Functions: []phpFunction{
222-
{Name: "test", ReturnType: "void", goFunction: "func test() {}"},
222+
{Name: "test", ReturnType: "void", GoFunction: "func test() {}"},
223223
},
224224
}
225225

@@ -296,7 +296,7 @@ func test() {}`
296296
BaseName: "importtest",
297297
SourceFile: sourceFile,
298298
Functions: []phpFunction{
299-
{Name: "test", ReturnType: "void", goFunction: "func test() {}"},
299+
{Name: "test", ReturnType: "void", GoFunction: "func test() {}"},
300300
},
301301
}
302302

@@ -372,7 +372,7 @@ func debugPrint(msg string) {
372372
{
373373
Name: "processData",
374374
ReturnType: "array",
375-
goFunction: `func processData(input *go_string, options *go_nullable) *go_value {
375+
GoFunction: `func processData(input *go_string, options *go_nullable) *go_value {
376376
data := CStringToGoString(input)
377377
processed := internalProcess(data)
378378
return Array([]interface{}{processed})
@@ -381,7 +381,7 @@ func debugPrint(msg string) {
381381
{
382382
Name: "validateInput",
383383
ReturnType: "bool",
384-
goFunction: `func validateInput(data *go_string) *go_value {
384+
GoFunction: `func validateInput(data *go_string) *go_value {
385385
input := CStringToGoString(data)
386386
isValid := len(input) > 0 && validateFormat(input)
387387
return Bool(isValid)
@@ -459,7 +459,7 @@ func (ts *TestStruct) ProcessData(name string, count *int64, enabled *bool) stri
459459
{Name: "count", PhpType: "int", IsNullable: true},
460460
{Name: "enabled", PhpType: "bool", IsNullable: true},
461461
},
462-
goFunction: `func (ts *TestStruct) ProcessData(name string, count *int64, enabled *bool) string {
462+
GoFunction: `func (ts *TestStruct) ProcessData(name string, count *int64, enabled *bool) string {
463463
result := fmt.Sprintf("name=%s", name)
464464
if count != nil {
465465
result += fmt.Sprintf(", count=%d", *count)

internal/extgen/nodes.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
type phpFunction struct {
99
Name string
1010
Signature string
11-
goFunction string
11+
GoFunction string
1212
Params []phpParameter
1313
ReturnType string
1414
IsReturnNullable bool
@@ -34,7 +34,8 @@ type phpClassMethod struct {
3434
Name string
3535
PhpName string
3636
Signature string
37-
goFunction string
37+
GoFunction string
38+
Wrapper string
3839
Params []phpParameter
3940
ReturnType string
4041
isReturnNullable bool
@@ -45,7 +46,7 @@ type phpClassMethod struct {
4546
type phpClassProperty struct {
4647
Name string
4748
PhpType string
48-
goType string
49+
GoType string
4950
IsNullable bool
5051
}
5152

internal/extgen/stub.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ func getPhpTypeAnnotation(goType string) string {
4646
case "string", "bool", "float", "int":
4747
return goType
4848
default:
49-
return "int" // fallback
49+
return "int"
5050
}
5151
}

0 commit comments

Comments
 (0)