@@ -398,6 +398,7 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation
398398
399399 switch (bufferType)
400400 {
401+ case WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE:
401402 case WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE:
402403 case WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE:
403404 {
@@ -797,7 +798,7 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation
797798 // - an optional HTTP method (defaults to POST if content is present, GET otherwise)
798799 // Return value:
799800 // - the response from the server as a json value
800- WDJ::JsonObject AzureConnection::_SendRequestReturningJson (std::wstring_view uri, const WWH::IHttpContent& content, WWH::HttpMethod method)
801+ WDJ::JsonObject AzureConnection::_SendRequestReturningJson (std::wstring_view uri, const WWH::IHttpContent& content, WWH::HttpMethod method, const Windows::Foundation::Uri referer )
801802 {
802803 if (!method)
803804 {
@@ -810,6 +811,11 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation
810811 auto headers{ request.Headers () };
811812 headers.Accept ().TryParseAdd (L" application/json" );
812813
814+ if (referer)
815+ {
816+ headers.Referer (referer);
817+ }
818+
813819 const auto response{ _httpClient.SendRequestAsync (request).get () };
814820 const auto string{ response.Content ().ReadAsStringAsync ().get () };
815821 const auto jsonResult{ WDJ::JsonObject::Parse (string) };
@@ -974,17 +980,56 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation
974980 auto uri{ fmt::format (L" {}terminals?cols={}&rows={}&version=2019-01-01&shell={}" , _cloudShellUri, _initialCols, _initialRows, shellType) };
975981
976982 WWH::HttpStringContent content{
977- L" " ,
983+ L" {} " ,
978984 WSS::UnicodeEncoding::Utf8,
979985 // LOAD-BEARING. the API returns "'content-type' should be 'application/json' or 'multipart/form-data'"
980986 L" application/json"
981987 };
982988
983- const auto terminalResponse = _SendRequestReturningJson (uri, content);
989+ const auto terminalResponse = _SendRequestReturningJson (uri, content, WWH::HttpMethod::Post (), Windows::Foundation::Uri (_cloudShellUri) );
984990 _terminalID = terminalResponse.GetNamedString (L" id" );
985991
992+ // we have to do some post-handling to get the proper socket endpoint
993+ // the logic here is based on the way the cloud shell team itself does it
994+ winrt::hstring finalSocketUri;
995+ const std::wstring_view wCloudShellUri{ _cloudShellUri };
996+
997+ if (wCloudShellUri.find (L" servicebus" ) == std::wstring::npos)
998+ {
999+ // wCloudShellUri does not contain the word "servicebus", we can just use it to make the final URI
1000+
1001+ // remove the "https" from the cloud shell URI
1002+ const auto uriWithoutProtocol = wCloudShellUri.substr (5 );
1003+
1004+ finalSocketUri = fmt::format (FMT_COMPILE (L" wss{}terminals/{}" ), uriWithoutProtocol, _terminalID);
1005+ }
1006+ else
1007+ {
1008+ // if wCloudShellUri contains the word "servicebus", that means the returned socketUri is of the form
1009+ // wss://ccon-prod-westus-aci-03.servicebus.windows.net/cc-AAAA-AAAAAAAA//aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
1010+ // we need to change it to:
1011+ // wss://ccon-prod-westus-aci-03.servicebus.windows.net/$hc/cc-AAAA-AAAAAAAA/terminals/aaaaaaaaaaaaaaaaaaaaaa
1012+
1013+ const auto socketUri = terminalResponse.GetNamedString (L" socketUri" );
1014+ const std::wstring_view wSocketUri{ socketUri };
1015+
1016+ // get the substring up until the ".net"
1017+ const auto dotNetStart = wSocketUri.find (L" .net" );
1018+ THROW_HR_IF (E_UNEXPECTED, dotNetStart == std::wstring::npos);
1019+ const auto dotNetEnd = dotNetStart + 4 ;
1020+ const auto wSocketUriBody = wSocketUri.substr (0 , dotNetEnd);
1021+
1022+ // get the portion between the ".net" and the "//" (this is the cc-AAAA-AAAAAAAA part)
1023+ const auto lastDoubleSlashPos = wSocketUri.find_last_of (L" //" );
1024+ THROW_HR_IF (E_UNEXPECTED, lastDoubleSlashPos == std::wstring::npos);
1025+ const auto wSocketUriMiddle = wSocketUri.substr (dotNetEnd, lastDoubleSlashPos - (dotNetEnd));
1026+
1027+ // piece together the final uri, adding in the "$hc" and "terminals" where needed
1028+ finalSocketUri = fmt::format (FMT_COMPILE (L" {}/$hc{}terminals/{}" ), wSocketUriBody, wSocketUriMiddle, _terminalID);
1029+ }
1030+
9861031 // Return the uri
987- return terminalResponse. GetNamedString ( L" socketUri " ) ;
1032+ return winrt::hstring{ finalSocketUri } ;
9881033 }
9891034
9901035 // Method description:
0 commit comments