diff --git a/backend/api/api.go b/backend/api/api.go index f562b08..61db6cb 100644 --- a/backend/api/api.go +++ b/backend/api/api.go @@ -48,6 +48,7 @@ func (a *APIHandler) postHandler(c *gin.Context) { action := c.Param("postAction") remoteIP := getRemoteIp(c) loginUser := GetLoginUser(c) + hostname := getHostname(c) switch action { case "login": @@ -88,13 +89,12 @@ func (a *APIHandler) postHandler(c *gin.Context) { act := c.Request.FormValue("action") data := c.Request.FormValue("data") userLinks := c.Request.FormValue("userLinks") - outJsons := c.Request.FormValue("outJsons") - err = a.ConfigService.Save(obj, act, json.RawMessage(data), json.RawMessage(userLinks), json.RawMessage(outJsons), loginUser) + objs, err := a.ConfigService.Save(obj, act, json.RawMessage(data), json.RawMessage(userLinks), loginUser, hostname) if err != nil { jsonMsg(c, "save", err) return } - err = a.loadPartialData(c, obj, len(outJsons) > 5, len(userLinks) > 5) + err = a.loadPartialData(c, objs) if err != nil { jsonMsg(c, obj, err) } @@ -133,7 +133,7 @@ func (a *APIHandler) getHandler(c *gin.Context) { } jsonObj(c, data, nil) case "inbounds", "outbounds", "endpoints", "tls", "clients", "config": - err := a.loadPartialData(c, action, false, false) + err := a.loadPartialData(c, []string{action}) if err != nil { jsonMsg(c, action, err) } @@ -257,62 +257,51 @@ func (a *APIHandler) loadData(c *gin.Context) (interface{}, error) { return data, nil } -func (a *APIHandler) loadPartialData(c *gin.Context, obj string, plusInbounds bool, plusClients bool) error { +func (a *APIHandler) loadPartialData(c *gin.Context, objs []string) error { data := make(map[string]interface{}, 0) - switch obj { - case "inbounds": - id := c.Query("id") - inbounds, err := a.InboundService.Get(id) - if err != nil { - return err + + for _, obj := range objs { + switch obj { + case "inbounds": + id := c.Query("id") + inbounds, err := a.InboundService.Get(id) + if err != nil { + return err + } + data[obj] = inbounds + case "outbounds": + outbounds, err := a.OutboundService.GetAll() + if err != nil { + return err + } + data[obj] = outbounds + case "endpoints": + endpoints, err := a.EndpointService.GetAll() + if err != nil { + return err + } + data[obj] = endpoints + case "tls": + tlsConfigs, err := a.TlsService.GetAll() + if err != nil { + return err + } + data[obj] = tlsConfigs + case "clients": + clients, err := a.ClientService.GetAll() + if err != nil { + return err + } + data[obj] = clients + case "config": + config, err := a.SettingService.GetConfig() + if err != nil { + return err + } + data[obj] = json.RawMessage(config) } - data[obj] = inbounds - case "outbounds": - outbounds, err := a.OutboundService.GetAll() - if err != nil { - return err - } - data[obj] = outbounds - case "endpoints": - endpoints, err := a.EndpointService.GetAll() - if err != nil { - return err - } - data[obj] = endpoints - case "tls": - tlsConfigs, err := a.TlsService.GetAll() - if err != nil { - return err - } - data[obj] = tlsConfigs - case "clients": - clients, err := a.ClientService.GetAll() - if err != nil { - return err - } - data[obj] = clients - case "config": - config, err := a.SettingService.GetConfig() - if err != nil { - return err - } - data[obj] = json.RawMessage(config) } - if plusInbounds { - inbounds, err := a.InboundService.GetAll() - if err != nil { - return err - } - data["inbounds"] = inbounds - } - if plusClients { - clients, err := a.ClientService.GetAll() - if err != nil { - return err - } - data["clients"] = clients - } jsonObj(c, data, nil) return nil } diff --git a/backend/api/utils.go b/backend/api/utils.go index ab19cb4..e26230e 100644 --- a/backend/api/utils.go +++ b/backend/api/utils.go @@ -27,6 +27,14 @@ func getRemoteIp(c *gin.Context) string { } } +func getHostname(c *gin.Context) string { + host := c.Request.Host + if colonIndex := strings.LastIndex(host, ":"); colonIndex != -1 { + host, _, _ = net.SplitHostPort(c.Request.Host) + } + return host +} + func jsonMsg(c *gin.Context, msg string, err error) { jsonMsgObj(c, msg, nil, err) } diff --git a/backend/service/config.go b/backend/service/config.go index 78d1c4a..a72d151 100644 --- a/backend/service/config.go +++ b/backend/service/config.go @@ -123,7 +123,7 @@ func (s *ConfigService) StopCore() error { return nil } -func (s *ConfigService) Save(obj string, act string, data json.RawMessage, userLinks json.RawMessage, outJsons json.RawMessage, loginUser string) error { +func (s *ConfigService) Save(obj string, act string, data json.RawMessage, userLinks json.RawMessage, loginUser string, hostname string) ([]string, error) { var err error var inboundIds []uint @@ -132,16 +132,6 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, userL defer func() { if err == nil { tx.Commit() - if len(inboundIds) > 0 && corePtr.IsRunning() { - err1 := s.InboundService.RestartInbounds(tx, inboundIds) - if err1 != nil { - logger.Error("unable to restart inbounds: ", err1) - } - } - // Try to start core if it is not running - if !corePtr.IsRunning() { - s.StartCore("") - } } else { tx.Rollback() } @@ -153,7 +143,7 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, userL case "tls": inboundIds, err = s.TlsService.Save(tx, act, data) case "inbounds": - err = s.InboundService.Save(tx, act, data) + err = s.InboundService.Save(tx, act, data, hostname) case "outbounds": err = s.OutboundService.Save(tx, act, data) case "endpoints": @@ -161,27 +151,20 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, userL case "config": err = s.SettingService.SaveConfig(tx, data) if err != nil { - return err + return nil, err } err = s.restartCoreWithConfig(data) default: - return common.NewError("unknown object: ", obj) + return nil, common.NewError("unknown object: ", obj) } if err != nil { - return err + return nil, err } if len(userLinks) > 0 { err = s.ClientService.UpdateLinks(tx, userLinks) if err != nil { - return err - } - } - - if len(outJsons) > 0 { - err = s.InboundService.UpdateOutJsons(tx, outJsons) - if err != nil { - return err + return nil, err } } @@ -194,11 +177,46 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, userL Obj: data, }).Error if err != nil { - return err + return nil, err } + // Commit changes so far + tx.Commit() LastUpdate = time.Now().Unix() + var objs []string = []string{obj} + tx = db.Begin() - return nil + // Update side changes + + // Update client links + if len(userLinks) > 0 { + err = s.ClientService.UpdateLinks(tx, userLinks) + if err != nil { + return nil, err + } + objs = append(objs, "clients") + } + + // Update out_json of inbounds when tls is changed + if obj == "tls" && len(inboundIds) > 0 { + err = s.InboundService.UpdateOutJsons(tx, inboundIds, hostname) + if err != nil { + return nil, common.NewError("unable to update out_json of inbounds: ", err.Error()) + } + objs = append(objs, "inbounds") + } + + if len(inboundIds) > 0 && corePtr.IsRunning() { + err1 := s.InboundService.RestartInbounds(tx, inboundIds) + if err1 != nil { + logger.Error("unable to restart inbounds: ", err1) + } + } + // Try to start core if it is not running + if !corePtr.IsRunning() { + s.StartCore("") + } + + return objs, nil } func (s *ConfigService) CheckChanges(lu string) (bool, error) { diff --git a/backend/service/inbounds.go b/backend/service/inbounds.go index a16775e..70881ae 100644 --- a/backend/service/inbounds.go +++ b/backend/service/inbounds.go @@ -5,6 +5,7 @@ import ( "os" "s-ui/database" "s-ui/database/model" + "s-ui/util" "strings" "gorm.io/gorm" @@ -75,7 +76,7 @@ func (s *InboundService) FromIds(ids []uint) ([]*model.Inbound, error) { return inbounds, nil } -func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage) error { +func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage, hostname string) error { var err error switch act { @@ -85,6 +86,13 @@ func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage) err if err != nil { return err } + if inbound.TlsId > 0 { + err = tx.Model(model.Tls{}).Where("id = ?", inbound.TlsId).Find(&inbound.Tls).Error + if err != nil { + return err + } + } + if corePtr.IsRunning() { if act == "edit" { err = corePtr.RemoveInbound(inbound.Tag) @@ -93,13 +101,6 @@ func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage) err } } - if inbound.TlsId > 0 { - err = tx.Model(model.Tls{}).Where("id = ?", inbound.TlsId).Find(&inbound.Tls).Error - if err != nil { - return err - } - } - inboundConfig, err := inbound.MarshalJSON() if err != nil { return err @@ -116,6 +117,11 @@ func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage) err } } + err = util.FillOutJson(&inbound, hostname) + if err != nil { + return err + } + err = tx.Save(&inbound).Error if err != nil { return err @@ -140,20 +146,18 @@ func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage) err return nil } -func (s *InboundService) UpdateOutJsons(tx *gorm.DB, data json.RawMessage) error { - var outJsons []interface{} - err := json.Unmarshal(data, &outJsons) +func (s *InboundService) UpdateOutJsons(tx *gorm.DB, inboundIds []uint, hostname string) error { + var inbounds []model.Inbound + err := tx.Model(model.Inbound{}).Preload("Tls").Where("id in ?", inboundIds).Find(&inbounds).Error if err != nil { return err } - for _, outJson := range outJsons { - outJsonData := outJson.(map[string]interface{}) - tag := outJsonData["tag"].(string) - outJson, err := json.MarshalIndent(outJsonData["out_json"], "", " ") + for _, inbound := range inbounds { + err = util.FillOutJson(&inbound, hostname) if err != nil { return err } - err = tx.Model(model.Inbound{}).Where("tag = ?", tag).Update("out_json", outJson).Error + err = tx.Model(model.Inbound{}).Where("tag = ?", inbound.Tag).Update("out_json", inbound.OutJson).Error if err != nil { return err } diff --git a/backend/util/outJson.go b/backend/util/outJson.go new file mode 100644 index 0000000..aaa38ea --- /dev/null +++ b/backend/util/outJson.go @@ -0,0 +1,183 @@ +package util + +import ( + "encoding/json" + "math/rand" + "s-ui/database/model" +) + +// Fill Inbound's out_json +func FillOutJson(i *model.Inbound, hostname string) error { + var outJson map[string]interface{} + err := json.Unmarshal(i.OutJson, &outJson) + if err != nil { + return err + } + + if i.TlsId > 0 { + addTls(&outJson, i.Tls) + } else { + delete(outJson, "tls") + } + + inbound, err := i.MarshalFull() + + outJson["type"] = i.Type + outJson["tag"] = i.Tag + outJson["server"] = hostname + outJson["server_port"] = (*inbound)["listen_port"] + + switch i.Type { + case "http", "socks", "mixed": + case "shadowsocks": + shadowsocksOut(&outJson, *inbound) + return nil + case "shadowtls": + shadowTlsOut(&outJson, *inbound) + case "hysteria": + hysteriaOut(&outJson, *inbound) + case "hysteria2": + hysteria2Out(&outJson, *inbound) + case "tuic": + tuicOut(&outJson, *inbound) + case "vless": + vlessOut(&outJson, *inbound) + case "trojan": + trojanOut(&outJson, *inbound) + case "vmess": + vmessOut(&outJson, *inbound) + default: + for key := range outJson { + delete(outJson, key) + } + } + + i.OutJson, err = json.MarshalIndent(outJson, "", " ") + if err != nil { + return err + } + + return nil +} + +// addTls function +func addTls(out *map[string]interface{}, tls *model.Tls) { + var tlsServer, tlsConfig map[string]interface{} + err := json.Unmarshal(tls.Server, &tlsServer) + if err != nil { + return + } + err = json.Unmarshal(tls.Client, &tlsConfig) + if err != nil { + return + } + + if enabled, ok := tlsServer["enabled"]; ok { + tlsConfig["enabled"] = enabled + } + if serverName, ok := tlsServer["server_name"]; ok { + tlsConfig["server_name"] = serverName + } + if alpn, ok := tlsServer["alpn"]; ok { + tlsConfig["alpn"] = alpn + } + if minVersion, ok := tlsServer["min_version"]; ok { + tlsConfig["min_version"] = minVersion + } + if maxVersion, ok := tlsServer["max_version"]; ok { + tlsConfig["max_version"] = maxVersion + } + if cipherSuites, ok := tlsServer["cipher_suites"]; ok { + tlsConfig["cipher_suites"] = cipherSuites + } + if reality, ok := tlsServer["reality"].(map[string]interface{}); ok && reality["enabled"].(bool) { + realityConfig := tlsConfig["reality"].(map[string]interface{}) + realityConfig["enabled"] = true + if shortIDs, ok := reality["short_id"].([]interface{}); ok && len(shortIDs) > 0 { + realityConfig["short_id"] = shortIDs[rand.Intn(len(shortIDs))] + } + tlsConfig["reality"] = realityConfig + } + + (*out)["tls"] = tlsConfig +} + +// Protocol-specific functions +func shadowsocksOut(out *map[string]interface{}, inbound map[string]interface{}) { + if method, ok := inbound["method"].(string); ok { + (*out)["method"] = method + } +} + +func shadowTlsOut(out *map[string]interface{}, inbound map[string]interface{}) { + if version, ok := inbound["version"].(float64); ok && int(version) == 3 { + (*out)["version"] = 3 + } else { + for key := range *out { + delete(*out, key) + } + } + (*out)["tls"] = map[string]interface{}{"enabled": true} +} + +func hysteriaOut(out *map[string]interface{}, inbound map[string]interface{}) { + if upMbps, ok := inbound["down_mbps"]; ok { + (*out)["up_mbps"] = upMbps + } + if downMbps, ok := inbound["up_mbps"]; ok { + (*out)["down_mbps"] = downMbps + } + if obfs, ok := inbound["obfs"]; ok { + (*out)["obfs"] = obfs + } + if recvWindow, ok := inbound["recv_window_conn"]; ok { + (*out)["recv_window_conn"] = recvWindow + } + if disableMTU, ok := inbound["disable_mtu_discovery"]; ok { + (*out)["disable_mtu_discovery"] = disableMTU + } +} + +func hysteria2Out(out *map[string]interface{}, inbound map[string]interface{}) { + if upMbps, ok := inbound["down_mbps"]; ok { + (*out)["up_mbps"] = upMbps + } + if downMbps, ok := inbound["up_mbps"]; ok { + (*out)["down_mbps"] = downMbps + } + if obfs, ok := inbound["obfs"]; ok { + (*out)["obfs"] = obfs + } +} + +func tuicOut(out *map[string]interface{}, inbound map[string]interface{}) { + if congestionControl, ok := inbound["congestion_control"].(string); ok { + (*out)["congestion_control"] = congestionControl + } else { + (*out)["congestion_control"] = "cubic" + } + if zeroRTT, ok := inbound["zero_rtt_handshake"].(bool); ok { + (*out)["zero_rtt_handshake"] = zeroRTT + } + if heartbeat, ok := inbound["heartbeat"]; ok { + (*out)["heartbeat"] = heartbeat + } +} + +func vlessOut(out *map[string]interface{}, inbound map[string]interface{}) { + if transport, ok := inbound["transport"]; ok { + (*out)["transport"] = transport + } +} + +func trojanOut(out *map[string]interface{}, inbound map[string]interface{}) { + if transport, ok := inbound["transport"]; ok { + (*out)["transport"] = transport + } +} + +func vmessOut(out *map[string]interface{}, inbound map[string]interface{}) { + if transport, ok := inbound["transport"]; ok { + (*out)["transport"] = transport + } +} diff --git a/frontend/src/plugins/outJson.ts b/frontend/src/plugins/outJson.ts deleted file mode 100644 index f9a726e..0000000 --- a/frontend/src/plugins/outJson.ts +++ /dev/null @@ -1,104 +0,0 @@ -import { Hysteria, Hysteria2, Inbound, InTypes, Shadowsocks, Trojan, TUIC, VLESS, VMess, ShadowTLS } from "@/types/inbounds" -import { iTls } from "@/types/inTls" -import { oTls } from "@/types/outTls" -import RandomUtil from "./randomUtil" - -export function fillData(inbound: Inbound, tls: any | null = null) { - if (tls != null) { - addTls(inbound.out_json, tls.server, tls.client) - } else { - delete inbound.out_json.tls - } - inbound.out_json.type = inbound.type - inbound.out_json.tag = inbound.tag - inbound.out_json.server = location.hostname - inbound.out_json.server_port = inbound.listen_port - switch(inbound.type){ - case InTypes.HTTP: case InTypes.SOCKS: case InTypes.Mixed: - return - case InTypes.Shadowsocks: - shadowsocksOut(inbound.out_json, inbound) - return - case InTypes.ShadowTLS: - shadowTlsOut(inbound.out_json, inbound) - return - case InTypes.Hysteria: - hysteriaOut(inbound.out_json, inbound) - return - case InTypes.Hysteria2: - hysteria2Out(inbound.out_json, inbound) - return - case InTypes.TUIC: - tuicOut(inbound.out_json, inbound) - return - case InTypes.VLESS: - vlessOut(inbound.out_json, inbound) - return - case InTypes.Trojan: - trojanOut(inbound.out_json, inbound) - return - case InTypes.VMess: - vmessOut(inbound.out_json, inbound) - return - } - Object.keys(inbound.out_json).forEach(key => delete inbound.out_json[key]) -} - -function addTls(out: any, tls: iTls, tlsClient: oTls){ - out.tls = tlsClient?? {} - if(tls.enabled) out.tls.enabled = tls.enabled - if(tls.server_name) out.tls.server_name = tls.server_name - if(tls.alpn) out.tls.alpn = tls.alpn - if(tls.min_version) out.tls.min_version = tls.min_version - if(tls.max_version) out.tls.max_version = tls.max_version - if(tls.cipher_suites) out.tls.cipher_suites = tls.cipher_suites - if(tls.reality?.enabled){ - out.tls.reality.enabled = true - out.tls.reality.short_id = tls.reality.short_id[RandomUtil.randomInt(tls.reality.short_id.length)] - } -} - -function shadowsocksOut(out: any, inbound: Shadowsocks) { - out.method = inbound.method -} - -function shadowTlsOut(out: any, inbound: ShadowTLS) { - if (inbound.version == 3) { - out.version = 3 - } else { - Object.keys(out).forEach(key => delete out[key]) - } - out.tls = { enabled: true } -} - -function hysteriaOut(out: any, inbound: Hysteria) { - out.up_mbps = inbound.down_mbps - out.down_mbps = inbound.up_mbps - out.obfs = inbound.obfs - out.recv_window_conn = inbound.recv_window_conn - out.disable_mtu_discovery = inbound.disable_mtu_discovery -} - -function hysteria2Out(out: any, inbound: Hysteria2) { - out.up_mbps = inbound.down_mbps - out.down_mbps = inbound.up_mbps - out.obfs = inbound.obfs -} - -function tuicOut(out: any, inbound: TUIC) { - out.congestion_control = inbound.congestion_control?? "cubic" - out.zero_rtt_handshake = inbound.zero_rtt_handshake - out.heartbeat = inbound.heartbeat -} - -function vlessOut(out: any, inbound: VLESS) { - out.transport = inbound.transport -} - -function trojanOut(out: any, inbound: Trojan) { - out.transport = inbound.transport -} - -function vmessOut(out: any, inbound: VMess) { - out.transport = inbound.transport -} diff --git a/frontend/src/views/Inbounds.vue b/frontend/src/views/Inbounds.vue index 9da8321..a5e0c3d 100644 --- a/frontend/src/views/Inbounds.vue +++ b/frontend/src/views/Inbounds.vue @@ -115,7 +115,6 @@ import { Client } from '@/types/clients' import { Link, LinkUtil } from '@/plugins/link' import { i18n } from '@/locales' import { push } from 'notivue' -import { fillData } from '@/plugins/outJson' const appConfig = computed((): Config => { return Data().config @@ -168,11 +167,6 @@ const saveModal = async (data:Inbound) => { }) return } - - // Fill outjson - if (data.out_json){ - fillData(data, data.tls_id > 0 ? tlsConfigs?.value.findLast((t:any) => t.id == data.tls_id) : null) - } let userLinkDiff = [] // Update links diff --git a/frontend/src/views/Tls.vue b/frontend/src/views/Tls.vue index ac427d7..cd172d9 100644 --- a/frontend/src/views/Tls.vue +++ b/frontend/src/views/Tls.vue @@ -91,7 +91,6 @@ import { computed, ref } from 'vue' import { Inbound, inboundWithUsers } from '@/types/inbounds' import { Client } from '@/types/clients' import { Link, LinkUtil } from '@/plugins/link' -import { fillData } from '@/plugins/outJson' const tlsConfigs = computed((): any[] => { return Data().tlsConfigs @@ -134,7 +133,6 @@ const closeModal = () => { modal.value.visible = false } const saveModal = async (data:any) => { - let outJsons = [] let userLinks = [] // New or Edit if (modal.value.id > 0) { @@ -142,11 +140,6 @@ const saveModal = async (data:any) => { if (inboundIds.length > 0) { const tlsInbounds = inboundIds.length == 0 ? [] : await Data().loadInbounds(inboundIds) for (const inbound of tlsInbounds) { - // Fill outjson - if (inbound.out_json) { - fillData(inbound, data) - } - outJsons.push({tag: inbound.tag,out_jsons: inbound.out_json}) // Update links const diff = updateLinks(inbound) diff.forEach((d: any) => { @@ -161,7 +154,7 @@ const saveModal = async (data:any) => { } } - const success = await Data().save("tls", data.id == 0 ? "new" : "edit", data, userLinks.length > 0 ? null: userLinks, outJsons.length > 0 ? null: outJsons) + const success = await Data().save("tls", data.id == 0 ? "new" : "edit", data, userLinks.length > 0 ? null: userLinks) if (success) modal.value.visible = false }