@@ -3,9 +3,12 @@ package client
33import (
44 "bytes"
55 "context" // nolint:gosec
6+ "crypto/ed25519"
7+ "encoding/base64"
68 "encoding/json"
79 "fmt"
810 "io"
11+ "math/rand"
912 "net/http"
1013 "net/http/httputil"
1114 "net/url"
@@ -16,7 +19,7 @@ import (
1619
1720 "github.com/matrix-org/gomatrixserverlib"
1821 "github.com/tidwall/gjson"
19- "maunium.net/go/mautrix/ crypto/olm "
22+ "golang.org/x/ crypto/curve25519 "
2023
2124 "github.com/matrix-org/complement/b"
2225 "github.com/matrix-org/complement/ct"
@@ -28,6 +31,12 @@ const (
2831 CtxKeyWithRetryUntil ctxKey = "complement_retry_until" // contains *retryUntilParams
2932)
3033
34+ var (
35+ // use a deterministic seed but globally so we don't generate the same numbers for each client.
36+ // This could be non-deterministic if used concurrently.
37+ prng = rand .New (rand .NewSource (42 ))
38+ )
39+
3140type retryUntilParams struct {
3241 timeout time.Duration
3342 untilFn func (* http.Response ) bool
@@ -403,10 +412,21 @@ func (c *CSAPI) MustUploadKeys(t ct.TestLike, deviceKeys map[string]interface{},
403412 return s .OTKCounts
404413}
405414
415+ // Generate realistic looking device keys and OTKs. They are not guaranteed to be 100% valid, but should
416+ // pass most server-side checks. Critically, these keys are generated using a Pseudo-Random Number Generator (PRNG)
417+ // for determinism and hence ARE NOT SECURE. DO NOT USE THIS OUTSIDE OF TESTS.
406418func (c * CSAPI ) MustGenerateOneTimeKeys (t ct.TestLike , otkCount uint ) (deviceKeys map [string ]interface {}, oneTimeKeys map [string ]interface {}) {
407419 t .Helper ()
408- account := olm .NewAccount ()
409- ed25519Key , curveKey := account .IdentityKeys ()
420+ ed25519PubKey , ed25519PrivKey , err := ed25519 .GenerateKey (prng )
421+ if err != nil {
422+ ct .Fatalf (t , "failed to generate ed25519 key: %s" , err )
423+ }
424+
425+ curveKey := make ([]byte , 32 )
426+ _ , err = prng .Read (curveKey )
427+ if err != nil {
428+ ct .Fatalf (t , "failed to read from prng: %s" , err )
429+ }
410430
411431 ed25519KeyID := fmt .Sprintf ("ed25519:%s" , c .DeviceID )
412432 curveKeyID := fmt .Sprintf ("curve25519:%s" , c .DeviceID )
@@ -416,38 +436,59 @@ func (c *CSAPI) MustGenerateOneTimeKeys(t ct.TestLike, otkCount uint) (deviceKey
416436 "device_id" : c .DeviceID ,
417437 "algorithms" : []interface {}{"m.olm.v1.curve25519-aes-sha2" , "m.megolm.v1.aes-sha2" },
418438 "keys" : map [string ]interface {}{
419- ed25519KeyID : ed25519Key . String ( ),
420- curveKeyID : curveKey . String ( ),
439+ ed25519KeyID : base64 . RawStdEncoding . EncodeToString ( ed25519PubKey ),
440+ curveKeyID : base64 . RawStdEncoding . EncodeToString ( curveKey ),
421441 },
422442 }
423443
424- signature , _ := account .SignJSON (deviceKeys )
444+ signJSON := func (input any ) []byte {
445+ inputJSON , err := json .Marshal (input )
446+ if err != nil {
447+ ct .Fatalf (t , "failed to marshal struct: %s" , err )
448+ }
449+ inputJSON , err = gomatrixserverlib .CanonicalJSON (inputJSON )
450+ if err != nil {
451+ ct .Fatalf (t , "failed to canonical json: %s" , err )
452+ }
453+ signature := ed25519 .Sign (ed25519PrivKey , inputJSON )
454+ if err != nil {
455+ ct .Fatalf (t , "failed to sign json: %s" , err )
456+ }
457+ return signature
458+ }
425459
426460 deviceKeys ["signatures" ] = map [string ]interface {}{
427461 c .UserID : map [string ]interface {}{
428- ed25519KeyID : signature ,
462+ ed25519KeyID : base64 . RawStdEncoding . EncodeToString ( signJSON ( deviceKeys )) ,
429463 },
430464 }
431-
432- account .GenOneTimeKeys (otkCount )
433465 oneTimeKeys = map [string ]interface {}{}
434466
435- for kid , key := range account .OneTimeKeys () {
467+ for i := uint (0 ); i < otkCount ; i ++ {
468+ privateKeyBytes := make ([]byte , 32 )
469+ _ , err = prng .Read (privateKeyBytes )
470+ if err != nil {
471+ ct .Fatalf (t , "failed to read from prng" , err )
472+ }
473+ key , err := curve25519 .X25519 (privateKeyBytes , curve25519 .Basepoint )
474+ if err != nil {
475+ ct .Fatalf (t , "failed to generate curve pubkey: %s" , err )
476+ }
477+ kid := fmt .Sprintf ("%d" , i )
436478 keyID := fmt .Sprintf ("signed_curve25519:%s" , kid )
437479 keyMap := map [string ]interface {}{
438- "key" : key . String ( ),
480+ "key" : base64 . RawStdEncoding . EncodeToString ( key ),
439481 }
440482
441- signature , _ = account .SignJSON (keyMap )
442-
443483 keyMap ["signatures" ] = map [string ]interface {}{
444484 c .UserID : map [string ]interface {}{
445- ed25519KeyID : signature ,
485+ ed25519KeyID : base64 . RawStdEncoding . EncodeToString ( signJSON ( keyMap )) ,
446486 },
447487 }
448488
449489 oneTimeKeys [keyID ] = keyMap
450490 }
491+
451492 return deviceKeys , oneTimeKeys
452493}
453494
0 commit comments