package playwright import ( "errors" "fmt" "reflect" "regexp" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/go-stack/stack" "github.com/playwright-community/playwright-go/internal/safe" ) var ( pkgSourcePathPattern = regexp.MustCompile(`.+[\\/]playwright-go[\\/][^\\/]+\.go`) apiNameTransform = regexp.MustCompile(`(?U)\(\*(.+)(Impl)?\)`) ) type connection struct { transport transport apiZone sync.Map objects *safe.SyncMap[string, *channelOwner] lastID atomic.Uint32 rootObject *rootChannelOwner callbacks *safe.SyncMap[uint32, *protocolCallback] afterClose func() onClose func() error isRemote bool localUtils *localUtilsImpl tracingCount atomic.Int32 abort chan struct{} abortOnce sync.Once err *safeValue[error] // for event listener error closedError *safeValue[error] } func (c *connection) Start() (*Playwright, error) { go func() { for { msg, err := c.transport.Poll() if err != nil { _ = c.transport.Close() c.cleanup(err) return } c.Dispatch(msg) } }() c.onClose = func() error { if err := c.transport.Close(); err != nil { return err } return nil } return c.rootObject.initialize() } func (c *connection) Stop() error { if err := c.onClose(); err != nil { return err } c.cleanup() return nil } func (c *connection) cleanup(cause ...error) { if len(cause) > 0 { c.closedError.Set(fmt.Errorf("%w: %w", ErrTargetClosed, cause[0])) } else { c.closedError.Set(ErrTargetClosed) } if c.afterClose != nil { c.afterClose() } c.abortOnce.Do(func() { select { case <-c.abort: default: close(c.abort) } }) } func (c *connection) Dispatch(msg *message) { if c.closedError.Get() != nil { return } method := msg.Method if msg.ID != 0 { cb, _ := c.callbacks.LoadAndDelete(uint32(msg.ID)) if cb.noReply { return } if msg.Error != nil { cb.SetError(parseError(msg.Error.Error)) } else { cb.SetResult(c.replaceGuidsWithChannels(msg.Result).(map[string]interface{})) } return } object, _ := c.objects.Load(msg.GUID) if method == "__create__" { c.createRemoteObject( object, msg.Params["type"].(string), msg.Params["guid"].(string), msg.Params["initializer"], ) return } if object == nil { return } if method == "__adopt__" { child, ok := c.objects.Load(msg.Params["guid"].(string)) if !ok { return } object.adopt(child) return } if method == "__dispose__" { reason, ok := msg.Params["reason"] if ok { object.dispose(reason.(string)) } else { object.dispose() } return } if object.objectType == "JsonPipe" { object.channel.Emit(method, msg.Params) } else { object.channel.Emit(method, c.replaceGuidsWithChannels(msg.Params)) } } func (c *connection) LocalUtils() *localUtilsImpl { return c.localUtils } func (c *connection) createRemoteObject(parent *channelOwner, objectType string, guid string, initializer interface{}) interface{} { initializer = c.replaceGuidsWithChannels(initializer) result := createObjectFactory(parent, objectType, guid, initializer.(map[string]interface{})) return result } func (c *connection) WrapAPICall(cb func() (interface{}, error), isInternal bool) (interface{}, error) { if _, ok := c.apiZone.Load("apiZone"); ok { return cb() } c.apiZone.Store("apiZone", serializeCallStack(isInternal)) return cb() } func (c *connection) replaceGuidsWithChannels(payload interface{}) interface{} { if payload == nil { return nil } v := reflect.ValueOf(payload) if v.Kind() == reflect.Slice { listV := payload.([]interface{}) for i := 0; i < len(listV); i++ { listV[i] = c.replaceGuidsWithChannels(listV[i]) } return listV } if v.Kind() == reflect.Map { mapV := payload.(map[string]interface{}) if guid, hasGUID := mapV["guid"]; hasGUID { if channelOwner, ok := c.objects.Load(guid.(string)); ok { return channelOwner.channel } } for key := range mapV { mapV[key] = c.replaceGuidsWithChannels(mapV[key]) } return mapV } return payload } func (c *connection) sendMessageToServer(object *channelOwner, method string, params interface{}, noReply bool) (cb *protocolCallback) { cb = newProtocolCallback(noReply, c.abort) if err := c.closedError.Get(); err != nil { cb.SetError(err) return } if object.wasCollected { cb.SetError(errors.New("The object has been collected to prevent unbounded heap growth.")) return } id := c.lastID.Add(1) c.callbacks.Store(id, cb) var ( metadata = make(map[string]interface{}, 0) stack = make([]map[string]interface{}, 0) ) apiZone, ok := c.apiZone.LoadAndDelete("apiZone") if ok { for k, v := range apiZone.(parsedStackTrace).metadata { metadata[k] = v } stack = append(stack, apiZone.(parsedStackTrace).frames...) } metadata["wallTime"] = time.Now().UnixMilli() message := map[string]interface{}{ "id": id, "guid": object.guid, "method": method, "params": params, // channel.MarshalJSON will replace channel with guid "metadata": metadata, } if c.tracingCount.Load() > 0 && len(stack) > 0 && object.guid != "localUtils" { c.LocalUtils().AddStackToTracingNoReply(id, stack) } if err := c.transport.Send(message); err != nil { cb.SetError(fmt.Errorf("could not send message: %w", err)) return } return } func (c *connection) setInTracing(isTracing bool) { if isTracing { c.tracingCount.Add(1) } else { c.tracingCount.Add(-1) } } type parsedStackTrace struct { frames []map[string]interface{} metadata map[string]interface{} } func serializeCallStack(isInternal bool) parsedStackTrace { st := stack.Trace().TrimRuntime() if len(st) == 0 { // https://github.com/go-stack/stack/issues/27 st = stack.Trace() } lastInternalIndex := 0 for i, s := range st { if pkgSourcePathPattern.MatchString(s.Frame().File) { lastInternalIndex = i } } apiName := "" if !isInternal { apiName = fmt.Sprintf("%n", st[lastInternalIndex]) } st = st.TrimBelow(st[lastInternalIndex]) callStack := make([]map[string]interface{}, 0) for i, s := range st { if i == 0 { continue } callStack = append(callStack, map[string]interface{}{ "file": s.Frame().File, "line": s.Frame().Line, "column": 0, "function": s.Frame().Function, }) } metadata := make(map[string]interface{}) if len(st) > 1 { metadata["location"] = serializeCallLocation(st[1]) } apiName = apiNameTransform.ReplaceAllString(apiName, "$1") if len(apiName) > 1 { apiName = strings.ToUpper(apiName[:1]) + apiName[1:] } metadata["apiName"] = apiName metadata["isInternal"] = isInternal return parsedStackTrace{ metadata: metadata, frames: callStack, } } func serializeCallLocation(caller stack.Call) map[string]interface{} { line, _ := strconv.Atoi(fmt.Sprintf("%d", caller)) return map[string]interface{}{ "file": fmt.Sprintf("%s", caller), "line": line, } } func newConnection(transport transport, localUtils ...*localUtilsImpl) *connection { connection := &connection{ abort: make(chan struct{}, 1), callbacks: safe.NewSyncMap[uint32, *protocolCallback](), objects: safe.NewSyncMap[string, *channelOwner](), transport: transport, isRemote: false, err: &safeValue[error]{}, closedError: &safeValue[error]{}, } if len(localUtils) > 0 { connection.localUtils = localUtils[0] connection.isRemote = true } connection.rootObject = newRootChannelOwner(connection) return connection } func fromChannel(v interface{}) interface{} { return v.(*channel).object } func fromNullableChannel(v interface{}) interface{} { if v == nil { return nil } return fromChannel(v) } type protocolCallback struct { done chan struct{} noReply bool abort <-chan struct{} once sync.Once value map[string]interface{} err error } func (pc *protocolCallback) setResultOnce(result map[string]interface{}, err error) { pc.once.Do(func() { pc.value = result pc.err = err close(pc.done) }) } func (pc *protocolCallback) waitResult() { if pc.noReply { return } select { case <-pc.done: // wait for result return case <-pc.abort: select { case <-pc.done: return default: pc.err = errors.New("Connection closed") return } } } func (pc *protocolCallback) SetError(err error) { pc.setResultOnce(nil, err) } func (pc *protocolCallback) SetResult(result map[string]interface{}) { pc.setResultOnce(result, nil) } func (pc *protocolCallback) GetResult() (map[string]interface{}, error) { pc.waitResult() return pc.value, pc.err } // GetResultValue returns value if the map has only one element func (pc *protocolCallback) GetResultValue() (interface{}, error) { pc.waitResult() if len(pc.value) == 0 { // empty map treated as nil return nil, pc.err } if len(pc.value) == 1 { for key := range pc.value { return pc.value[key], pc.err } } return pc.value, pc.err } func newProtocolCallback(noReply bool, abort <-chan struct{}) *protocolCallback { if noReply { return &protocolCallback{ noReply: true, abort: abort, } } return &protocolCallback{ done: make(chan struct{}, 1), abort: abort, } }