Refactor docker image by applying patch method (#13011)
* first stage try * second try * add ninja * Done * fix
This commit is contained in:
		
							parent
							
								
									7809ca9864
								
							
						
					
					
						commit
						61c2e9c271
					
				
					 4 changed files with 41745 additions and 7 deletions
				
			
		
							
								
								
									
										935
									
								
								docker/llm/serving/xpu/docker/1ccl_for_multi_arc.patch
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										935
									
								
								docker/llm/serving/xpu/docker/1ccl_for_multi_arc.patch
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,935 @@
 | 
				
			||||||
 | 
					From d345631f78a2f33ff1ddd7d9908b288eb0afaf46 Mon Sep 17 00:00:00 2001
 | 
				
			||||||
 | 
					From: Huajun Li <huajun.li@.com>
 | 
				
			||||||
 | 
					Date: Fri, 24 May 2024 09:47:26 +0800
 | 
				
			||||||
 | 
					Subject: [PATCH 1/3] allreduce optimization with LL256 for Arc770 dGPU
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					To enable this feature, please set env var:
 | 
				
			||||||
 | 
					    export CCL_DG2_ALLREDUCE=1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Build:
 | 
				
			||||||
 | 
					    1. mkdir build; cd build
 | 
				
			||||||
 | 
					    2. source /opt/intel/oneapi/setvars.sh
 | 
				
			||||||
 | 
					    3. cmake .. -GNinja -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DCMAKE_CXX_FLAGS="-fsycl" -DCOMPUTE_BACKEND=dpcpp  -DCMAKE_BUILD_TYPE=MinSizeRel
 | 
				
			||||||
 | 
					    4. ninja
 | 
				
			||||||
 | 
					    5. ls -al src/libccl*
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					 src/CMakeLists.txt               |   2 +
 | 
				
			||||||
 | 
					 src/coll/coll.cpp                |  30 +-
 | 
				
			||||||
 | 
					 src/coll/coll_param.cpp          |   1 +
 | 
				
			||||||
 | 
					 src/coll/selection/selection.cpp |   5 +
 | 
				
			||||||
 | 
					 src/common/env/env.cpp           |   1 +
 | 
				
			||||||
 | 
					 src/common/env/env.hpp           |   1 +
 | 
				
			||||||
 | 
					 src/common/env/vars.hpp          |   1 +
 | 
				
			||||||
 | 
					 src/dg2/dg2_allreduce.cpp        | 642 +++++++++++++++++++++++++++++++
 | 
				
			||||||
 | 
					 src/dg2/dg2_allreduce.hpp        |  13 +
 | 
				
			||||||
 | 
					 9 files changed, 693 insertions(+), 3 deletions(-)
 | 
				
			||||||
 | 
					 create mode 100644 src/dg2/dg2_allreduce.cpp
 | 
				
			||||||
 | 
					 create mode 100644 src/dg2/dg2_allreduce.hpp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
 | 
				
			||||||
 | 
					index c1f2483..82cbf22 100644
 | 
				
			||||||
 | 
					--- a/src/CMakeLists.txt
 | 
				
			||||||
 | 
					+++ b/src/CMakeLists.txt
 | 
				
			||||||
 | 
					@@ -263,6 +263,8 @@ set(CCL_SRC
 | 
				
			||||||
 | 
					     ccl_empty_kvs_attr.cpp
 | 
				
			||||||
 | 
					     ccl_empty_stream.cpp
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					+    dg2/dg2_allreduce.cpp
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					     ${EXTENSIONS_SRC})
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 if (ENABLE_STUB_BACKEND)
 | 
				
			||||||
 | 
					diff --git a/src/coll/coll.cpp b/src/coll/coll.cpp
 | 
				
			||||||
 | 
					index 9bdb88d..2c89711 100644
 | 
				
			||||||
 | 
					--- a/src/coll/coll.cpp
 | 
				
			||||||
 | 
					+++ b/src/coll/coll.cpp
 | 
				
			||||||
 | 
					@@ -62,6 +62,8 @@
 | 
				
			||||||
 | 
					 #include "sched/sched_timer.hpp"
 | 
				
			||||||
 | 
					 #include "unordered_coll/unordered_coll.hpp"
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					+#include "dg2/dg2_allreduce.hpp"
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					 #if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE)
 | 
				
			||||||
 | 
					 #include "coll/algorithms/utils/sycl_selection.hpp"
 | 
				
			||||||
 | 
					 #include "coll/algorithms/allreduce/sycl/allreduce_sycl.hpp"
 | 
				
			||||||
 | 
					@@ -129,6 +131,16 @@ ccl_request* exec_single_rank_coll(const ccl_coll_param& param) {
 | 
				
			||||||
 | 
					     return nullptr;
 | 
				
			||||||
 | 
					 }
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					+static ccl_request* ccl_dg2_allreduce_impl(ccl_coll_param& param, const ccl_coll_attr& in_attr)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    dg2_init(param);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    dg2_allreduce(param.send_bufs[0], param.recv_bufs[0],
 | 
				
			||||||
 | 
					+                  param.count, param.dtype, param.reduction, param.comm);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    return nullptr;
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					 /* param is not const because param.comm can be updated for unordered colls */
 | 
				
			||||||
 | 
					 static ccl_request* ccl_coll_create(ccl_coll_param& param, const ccl_coll_attr& in_attr) {
 | 
				
			||||||
 | 
					     ccl_coll_attr& attr = const_cast<ccl_coll_attr&>(in_attr);
 | 
				
			||||||
 | 
					@@ -1181,12 +1193,24 @@ ccl_request* ccl_allreduce_impl(const void* send_buf,
 | 
				
			||||||
 | 
					                                 ccl_comm* comm,
 | 
				
			||||||
 | 
					                                 const ccl_stream* stream,
 | 
				
			||||||
 | 
					                                 const std::vector<ccl::event>& deps) {
 | 
				
			||||||
 | 
					+    ccl_request *req;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					     ccl_coll_param param = ccl_coll_param::create_allreduce_param(
 | 
				
			||||||
 | 
					         send_buf, recv_buf, count, dtype, reduction, attr, comm, stream, deps);
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					-    auto req = ccl_coll_create(param, attr);
 | 
				
			||||||
 | 
					-    LOG_DEBUG(
 | 
				
			||||||
 | 
					-        "coll ", ccl_coll_type_to_str(param.ctype), " created, req ", stream, " count ", count);
 | 
				
			||||||
 | 
					+    std::shared_ptr<atl_base_comm> atl_comm = comm->get_node_comm().get()->get_atl_comm();
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    if (ccl::global_data::env().enable_dg2_allreduce
 | 
				
			||||||
 | 
					+        && (atl_comm->get_size() >= 2) && (atl_comm->get_size() <= DG2_NUM)
 | 
				
			||||||
 | 
					+        && ((dtype == ccl::datatype::float16) || (dtype == ccl::datatype::float32) || (dtype == ccl::datatype::int32))) {
 | 
				
			||||||
 | 
					+        req = ccl_dg2_allreduce_impl(param, attr);
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+    else {
 | 
				
			||||||
 | 
					+        req = ccl_coll_create(param, attr);
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    LOG_DEBUG("coll ", ccl_coll_type_to_str(param.ctype), " created, req ", req, " count ", count);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					     return req;
 | 
				
			||||||
 | 
					 }
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					diff --git a/src/coll/coll_param.cpp b/src/coll/coll_param.cpp
 | 
				
			||||||
 | 
					index 7ab7b57..e22fc1d 100644
 | 
				
			||||||
 | 
					--- a/src/coll/coll_param.cpp
 | 
				
			||||||
 | 
					+++ b/src/coll/coll_param.cpp
 | 
				
			||||||
 | 
					@@ -579,6 +579,7 @@ ccl_coll_param ccl_coll_param::create_allreduce_param(const void* send_buf,
 | 
				
			||||||
 | 
					     ccl_coll_param param{};
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					     param.ctype = ccl_coll_allreduce;
 | 
				
			||||||
 | 
					+    param.count = count;
 | 
				
			||||||
 | 
					     param.send_bufs.push_back((void*)send_buf);
 | 
				
			||||||
 | 
					     param.send_counts.push_back(count);
 | 
				
			||||||
 | 
					     param.recv_bufs.push_back(recv_buf);
 | 
				
			||||||
 | 
					diff --git a/src/coll/selection/selection.cpp b/src/coll/selection/selection.cpp
 | 
				
			||||||
 | 
					index 5deae74..f4d9302 100644
 | 
				
			||||||
 | 
					--- a/src/coll/selection/selection.cpp
 | 
				
			||||||
 | 
					+++ b/src/coll/selection/selection.cpp
 | 
				
			||||||
 | 
					@@ -393,6 +393,11 @@ bool ccl_can_use_topo_algo(const ccl_selector_param& param) {
 | 
				
			||||||
 | 
					             " is not supported for family1");
 | 
				
			||||||
 | 
					     }
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					+    if (ccl::global_data::env().enable_dg2_allreduce) {
 | 
				
			||||||
 | 
					+        LOG_DEBUG("topo algorithm is not supported by DG2");
 | 
				
			||||||
 | 
					+        return false;
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					     if (checkers::is_unknown_device_family(param)) {
 | 
				
			||||||
 | 
					         LOG_WARN("Applying topo algorithm, but device family is not recognized");
 | 
				
			||||||
 | 
					 #ifndef CCL_BF16_GPU_TRUNCATE
 | 
				
			||||||
 | 
					diff --git a/src/common/env/env.cpp b/src/common/env/env.cpp
 | 
				
			||||||
 | 
					index 11413cf..b00641c 100644
 | 
				
			||||||
 | 
					--- a/src/common/env/env.cpp
 | 
				
			||||||
 | 
					+++ b/src/common/env/env.cpp
 | 
				
			||||||
 | 
					@@ -468,6 +468,7 @@ void env_data::parse() {
 | 
				
			||||||
 | 
					     }
 | 
				
			||||||
 | 
					     p.env_2_enum(CCL_STAGING_BUFFER, staging_buffer_names, staging_buffer);
 | 
				
			||||||
 | 
					     p.env_2_type(CCL_OP_SYNC, enable_op_sync);
 | 
				
			||||||
 | 
					+    p.env_2_type(CCL_DG2_ALLREDUCE, enable_dg2_allreduce);
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					     p.env_2_type(CCL_CHUNK_COUNT, chunk_count);
 | 
				
			||||||
 | 
					     CCL_THROW_IF_NOT(chunk_count >= 1, "incorrect ", CCL_CHUNK_COUNT, " ", chunk_count);
 | 
				
			||||||
 | 
					diff --git a/src/common/env/env.hpp b/src/common/env/env.hpp
 | 
				
			||||||
 | 
					index baff33f..538d8e5 100644
 | 
				
			||||||
 | 
					--- a/src/common/env/env.hpp
 | 
				
			||||||
 | 
					+++ b/src/common/env/env.hpp
 | 
				
			||||||
 | 
					@@ -177,6 +177,7 @@ public:
 | 
				
			||||||
 | 
					     bool enable_strict_order;
 | 
				
			||||||
 | 
					     ccl_staging_buffer staging_buffer;
 | 
				
			||||||
 | 
					     bool enable_op_sync;
 | 
				
			||||||
 | 
					+    int enable_dg2_allreduce;
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					     size_t chunk_count;
 | 
				
			||||||
 | 
					     size_t min_chunk_size;
 | 
				
			||||||
 | 
					diff --git a/src/common/env/vars.hpp b/src/common/env/vars.hpp
 | 
				
			||||||
 | 
					index 73dcf77..84ab518 100644
 | 
				
			||||||
 | 
					--- a/src/common/env/vars.hpp
 | 
				
			||||||
 | 
					+++ b/src/common/env/vars.hpp
 | 
				
			||||||
 | 
					@@ -579,6 +579,7 @@ constexpr const char* CCL_BUFFER_CACHE = "CCL_BUFFER_CACHE";
 | 
				
			||||||
 | 
					 constexpr const char* CCL_STRICT_ORDER = "CCL_STRICT_ORDER";
 | 
				
			||||||
 | 
					 constexpr const char* CCL_STAGING_BUFFER = "CCL_STAGING_BUFFER";
 | 
				
			||||||
 | 
					 constexpr const char* CCL_OP_SYNC = "CCL_OP_SYNC";
 | 
				
			||||||
 | 
					+constexpr const char* CCL_DG2_ALLREDUCE = "CCL_DG2_ALLREDUCE";
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 constexpr const char* CCL_CHUNK_COUNT = "CCL_CHUNK_COUNT";
 | 
				
			||||||
 | 
					 constexpr const char* CCL_MIN_CHUNK_SIZE = "CCL_MIN_CHUNK_SIZE";
 | 
				
			||||||
 | 
					diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
 | 
				
			||||||
 | 
					new file mode 100644
 | 
				
			||||||
 | 
					index 0000000..15ace74
 | 
				
			||||||
 | 
					--- /dev/null
 | 
				
			||||||
 | 
					+++ b/src/dg2/dg2_allreduce.cpp
 | 
				
			||||||
 | 
					@@ -0,0 +1,642 @@
 | 
				
			||||||
 | 
					+#include <fcntl.h>
 | 
				
			||||||
 | 
					+#include <unistd.h>
 | 
				
			||||||
 | 
					+#include <sys/un.h>
 | 
				
			||||||
 | 
					+#include <sys/stat.h>
 | 
				
			||||||
 | 
					+#include <sys/types.h>
 | 
				
			||||||
 | 
					+#include <sys/ioctl.h>
 | 
				
			||||||
 | 
					+#include <sys/socket.h>
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#include <drm/drm.h>
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#include <mpi.h>
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#include <vector>
 | 
				
			||||||
 | 
					+#include <sstream>
 | 
				
			||||||
 | 
					+#include <iostream>
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#include <ext/intel/esimd.hpp>
 | 
				
			||||||
 | 
					+//#include <level_zero/ze_api.h>
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#include "oneapi/ccl.hpp"
 | 
				
			||||||
 | 
					+#include "common/global/global.hpp"
 | 
				
			||||||
 | 
					+#include "common/utils/exchange_utils.hpp"
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#include "dg2_allreduce.hpp"
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+using namespace std;
 | 
				
			||||||
 | 
					+using namespace sycl;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static sycl::queue q;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static int world_size = 0;
 | 
				
			||||||
 | 
					+static int world_rank = 0;
 | 
				
			||||||
 | 
					+static void *host_bufs[DG2_NUM];       /* host shared buf */
 | 
				
			||||||
 | 
					+static void *peer_bufs[DG2_NUM];       /* shared buf on peer side */
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#define FLAG_SIZE (32)                 /* FIXME: 256 bits */
 | 
				
			||||||
 | 
					+#define SOCK_PATH "/tmp/ipc_sock"
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+uint16_t pattern_counter = 0xa770;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+typedef uint32_t pattern_t;
 | 
				
			||||||
 | 
					+using message_t = sycl::vec<uint32_t, 4>;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#define SG_SZ      (16)                        /* Arc770: Subgroup Sizes Supported: 8;16;32, while 8 threads per EU */
 | 
				
			||||||
 | 
					+#define LS_SZ      (sizeof(message_t))         /* load/store byte size per work-item */
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#define __LscLoadUnCached(var, addr)   \
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("lsc_load.ugm.uc.uc   (M1, 16)  %0:d64  flat[%1]:a64" : "=rw"(var) : "rw"(addr) : "memory")
 | 
				
			||||||
 | 
					+#define __LscLoadCached(var, addr)   \
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("lsc_load.ugm.ca.ca   (M1, 16)  %0:d64  flat[%1]:a64" : "=rw"(var) : "rw"(addr) : "memory")
 | 
				
			||||||
 | 
					+#define __LscLoadUnCachedVec(var, addr)   \
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("lsc_load.ugm.uc.uc   (M1, 16)  %0:d32x4  flat[%1]:a64" : "=rw"(reinterpret_cast<typename message_t::vector_t &>(var)) : "rw"(addr) : "memory")
 | 
				
			||||||
 | 
					+#define __LscLoadCachedVec(var, addr)   \
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("lsc_load.ugm.ca.ca   (M1, 16)  %0:d32x4  flat[%1]:a64" : "=rw"(reinterpret_cast<typename message_t::vector_t &>(var)) : "rw"(addr) : "memory")
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#define __LscStoreUnCached(addr, var)  \
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("lsc_store.ugm.uc.uc  (M1, 16)  flat[%0]:a64  %1:d64" : : "rw"(addr), "rw"(var) : "memory")
 | 
				
			||||||
 | 
					+#define __LscStoreCached(addr, var)  \
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("lsc_store.ugm.ca.ca  (M1, 16)  flat[%0]:a64  %1:d64" : : "rw"(addr), "rw"(var) : "memory")
 | 
				
			||||||
 | 
					+#define __LscStoreUnCachedVec(addr, var)  \
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("lsc_store.ugm.uc.uc  (M1, 16)  flat[%0]:a64  %1:d32x4" : : "rw"(addr), "rw"(reinterpret_cast<typename message_t::vector_t &>(var)) : "memory")
 | 
				
			||||||
 | 
					+#define __LscStoreCachedVec(addr, var)  \
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("lsc_store.ugm.ca.ca  (M1, 16)  flat[%0]:a64  %1:d32x4" : : "rw"(addr), "rw"(reinterpret_cast<typename message_t::vector_t &>(var)) : "memory")
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#define LscLoadCached     __LscLoadCachedVec
 | 
				
			||||||
 | 
					+#define LscLoadUnCached   __LscLoadUnCachedVec
 | 
				
			||||||
 | 
					+#define LscStoreCached    __LscStoreCachedVec
 | 
				
			||||||
 | 
					+#define LscStoreUnCached  __LscStoreUnCachedVec
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+inline int get_fd_from_handle(const ze_ipc_mem_handle_t& handle)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    return *(reinterpret_cast<const int*>(handle.data));
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+inline ze_ipc_mem_handle_t get_handle_from_fd(int fd)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    ze_ipc_mem_handle_t handle{};
 | 
				
			||||||
 | 
					+    *(reinterpret_cast<int*>(handle.data)) = fd;
 | 
				
			||||||
 | 
					+    return handle;
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static int srv_sock(char *sock_path)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    int fd;
 | 
				
			||||||
 | 
					+    struct sockaddr_un sk;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    sk.sun_family = AF_UNIX;
 | 
				
			||||||
 | 
					+    strncpy(sk.sun_path, sock_path, sizeof(sk.sun_path)-1);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    unlink(sock_path);
 | 
				
			||||||
 | 
					+    fd = socket(PF_UNIX, SOCK_STREAM, 0);
 | 
				
			||||||
 | 
					+    if (fd < 0) {
 | 
				
			||||||
 | 
					+        return -1;
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    mode_t prev_umask = umask(0);
 | 
				
			||||||
 | 
					+    if (fchmod(fd, 0666) < 0) {
 | 
				
			||||||
 | 
					+        close(fd);
 | 
				
			||||||
 | 
					+        return -1;
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+    umask(prev_umask);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    if (bind(fd, (struct sockaddr *)&sk, sizeof(sk)) == -1) {
 | 
				
			||||||
 | 
					+        close(fd);
 | 
				
			||||||
 | 
					+        return -1;
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    if (fcntl(fd, F_SETFL, O_NONBLOCK) == -1) {
 | 
				
			||||||
 | 
					+        close(fd);
 | 
				
			||||||
 | 
					+        return -1;
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    if (listen(fd, 5) == -1) {
 | 
				
			||||||
 | 
					+        close(fd);
 | 
				
			||||||
 | 
					+        return -1;
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    return fd;
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static int cli_sock(char *sock_path)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    int fd;
 | 
				
			||||||
 | 
					+    struct sockaddr_un sk;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    sk.sun_family = AF_UNIX;
 | 
				
			||||||
 | 
					+    strncpy(sk.sun_path, sock_path, sizeof(sk.sun_path)-1);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    fd = socket(PF_UNIX, SOCK_STREAM, 0);
 | 
				
			||||||
 | 
					+    if (fd < 0) {
 | 
				
			||||||
 | 
					+        fprintf(stderr, "Failed to create socket(%d), sock path: %s\n", errno, sock_path);
 | 
				
			||||||
 | 
					+        return -1;
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    if (connect(fd, (struct sockaddr *)&sk, sizeof(sk)) == -1) {
 | 
				
			||||||
 | 
					+        //fprintf(stderr, "Failed to connect socket(%d)\n", errno);
 | 
				
			||||||
 | 
					+        close(fd);
 | 
				
			||||||
 | 
					+        return -1;
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    return fd;
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static void *thread_func(void *arg)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    fd_set fds;
 | 
				
			||||||
 | 
					+    int count = 0;
 | 
				
			||||||
 | 
					+    char sock_path[64];
 | 
				
			||||||
 | 
					+    int peer_buf_fd = 0;
 | 
				
			||||||
 | 
					+    int rank = *(int *)arg;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    snprintf(sock_path, sizeof(sock_path), "%s-%d_%d", SOCK_PATH, rank, 0xa770);
 | 
				
			||||||
 | 
					+    int srv_fd = srv_sock(sock_path);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    //std::cout << "-----> srv_fd of " << sock_path << " : " << srv_fd << "\n";
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    auto sycl_context = q.get_context();
 | 
				
			||||||
 | 
					+    auto sycl_device = q.get_device();
 | 
				
			||||||
 | 
					+    ze_context_handle_t ze_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
 | 
				
			||||||
 | 
					+    ze_device_handle_t  ze_device = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    FD_ZERO(&fds);
 | 
				
			||||||
 | 
					+    FD_SET(srv_fd, &fds);
 | 
				
			||||||
 | 
					+    while (++count < world_size) {
 | 
				
			||||||
 | 
					+        int ret = select(srv_fd + 1, &fds, NULL, NULL, NULL);
 | 
				
			||||||
 | 
					+        switch (ret) {
 | 
				
			||||||
 | 
					+        case 1:
 | 
				
			||||||
 | 
					+            {
 | 
				
			||||||
 | 
					+                int peer_rank;
 | 
				
			||||||
 | 
					+                void *peer_buf;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                int conn_fd = accept(srv_fd, NULL, 0);
 | 
				
			||||||
 | 
					+                ccl::utils::recvmsg_fd(conn_fd, &peer_buf_fd, &peer_rank, sizeof(peer_rank));
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                ze_ipc_mem_handle_t ipc_handle_peer_buf = get_handle_from_fd(peer_buf_fd);
 | 
				
			||||||
 | 
					+                zeMemOpenIpcHandle(ze_context, ze_device, ipc_handle_peer_buf, ZE_IPC_MEMORY_FLAG_BIAS_CACHED /* cached allocation */, &peer_buf);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                peer_bufs[peer_rank] = peer_buf;
 | 
				
			||||||
 | 
					+                //printf("<------------- rank: %d, peer_bufs[%d]: %p\n", world_rank, peer_rank, peer_bufs[peer_rank]);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                if (conn_fd > 0) close(conn_fd);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                break;
 | 
				
			||||||
 | 
					+            }
 | 
				
			||||||
 | 
					+        case 0:
 | 
				
			||||||
 | 
					+        case -1:
 | 
				
			||||||
 | 
					+            std::cout << "srv_fd select() failed" << "\n";
 | 
				
			||||||
 | 
					+            break;
 | 
				
			||||||
 | 
					+        default:
 | 
				
			||||||
 | 
					+            break;
 | 
				
			||||||
 | 
					+        }
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    if (srv_fd > 0) {
 | 
				
			||||||
 | 
					+        close(srv_fd);
 | 
				
			||||||
 | 
					+        unlink(sock_path);
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    return nullptr;
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+void create_shared_buf(void *send_buf, void *recv_buf, size_t byte_count)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    printf("-----> current rank: %d, world size: %d, byte_count: %lu\n", world_rank, world_size, byte_count);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    pthread_t tid;
 | 
				
			||||||
 | 
					+    char sock_path[64];
 | 
				
			||||||
 | 
					+    void *host_buf = nullptr;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    auto sycl_context = q.get_context();
 | 
				
			||||||
 | 
					+    ze_context_handle_t ze_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    /* thread to accept connection request from peers */
 | 
				
			||||||
 | 
					+    pthread_create(&tid, nullptr, thread_func, &world_rank);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    size_t buf_size = LL256_BUF_SIZE;
 | 
				
			||||||
 | 
					+    host_buf = sycl::aligned_alloc_device(getpagesize(), buf_size, q);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    host_bufs[world_rank] = host_buf;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    //printf("-------------> rank: %d, host_bufs[%d]: %p\n", world_rank, world_rank, host_bufs[world_rank]);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    for (int i = 0; i < world_size; i++) {
 | 
				
			||||||
 | 
					+        if (i != world_rank) {
 | 
				
			||||||
 | 
					+            int cli_fd = -1;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+            ze_ipc_mem_handle_t ipc_handle_host_buf;
 | 
				
			||||||
 | 
					+            zeMemGetIpcHandle(ze_context, host_buf, &ipc_handle_host_buf);
 | 
				
			||||||
 | 
					+            int host_buf_fd = get_fd_from_handle(ipc_handle_host_buf);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+            snprintf(sock_path, sizeof(sock_path), "%s-%d_%d", SOCK_PATH, i, 0xa770);
 | 
				
			||||||
 | 
					+            while (cli_fd < 0) cli_fd = cli_sock(sock_path);
 | 
				
			||||||
 | 
					+            //std::cout << "<----- cli_fd of " << sock_path << " : " << cli_fd << "\n";
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+            ccl::utils::sendmsg_fd(cli_fd, host_buf_fd, &world_rank, sizeof(world_rank));
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+            if (cli_fd > 0) close(cli_fd);
 | 
				
			||||||
 | 
					+        }
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    pthread_join(tid, nullptr);
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+void dg2_init(ccl_coll_param ¶m)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    size_t byte_count = param.count * param.dtype.size();
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    ccl_stream *stream = param.stream;
 | 
				
			||||||
 | 
					+    q = stream->get_native_stream();
 | 
				
			||||||
 | 
					+    // init recv_buf in advance to warm up GPU
 | 
				
			||||||
 | 
					+    if (param.send_bufs[0] != param.recv_bufs[0])
 | 
				
			||||||
 | 
					+        q.memcpy(param.recv_bufs[0], param.send_bufs[0], byte_count);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    /* init already */
 | 
				
			||||||
 | 
					+    if (world_size != 0)
 | 
				
			||||||
 | 
					+        return;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    ccl_comm *comm = param.comm;
 | 
				
			||||||
 | 
					+    std::shared_ptr<atl_base_comm> atl_comm = comm->get_node_comm().get()->get_atl_comm();
 | 
				
			||||||
 | 
					+    world_size = atl_comm->get_size();
 | 
				
			||||||
 | 
					+    world_rank = atl_comm->get_rank();
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    create_shared_buf(param.send_bufs[0], param.recv_bufs[0], byte_count);
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+template <typename T>
 | 
				
			||||||
 | 
					+static inline message_t _sum(message_t dst, message_t src)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    using math_t = sycl::vec<T, sizeof(message_t) / sizeof(T)>;
 | 
				
			||||||
 | 
					+    return sycl::bit_cast<message_t>(sycl::bit_cast<math_t>(dst) + sycl::bit_cast<math_t>(src));
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
 | 
				
			||||||
 | 
					+static inline message_t sum(message_t dst, message_t src, const ccl_datatype& dtype)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    message_t data;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    switch (dtype.idx()) {
 | 
				
			||||||
 | 
					+    case ccl::datatype::float16:
 | 
				
			||||||
 | 
					+        data = _sum<sycl::half>(dst, src);
 | 
				
			||||||
 | 
					+        break;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    case ccl::datatype::float32:
 | 
				
			||||||
 | 
					+        data = _sum<float>(dst, src);
 | 
				
			||||||
 | 
					+        break;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    case ccl::datatype::int32:
 | 
				
			||||||
 | 
					+        data = _sum<int32_t>(dst, src);
 | 
				
			||||||
 | 
					+        break;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    default:
 | 
				
			||||||
 | 
					+        /* following code will hurt performance */
 | 
				
			||||||
 | 
					+        //sycl::ext::oneapi::experimental::printf("Unknow dtype!\n");
 | 
				
			||||||
 | 
					+        break;
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    return data;
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static inline void sync_data(char *src, message_t &data, int lid, pattern_t pattern)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    size_t sz = sizeof(message_t);
 | 
				
			||||||
 | 
					+    auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    do {
 | 
				
			||||||
 | 
					+        LscLoadUnCached(data, src + lid * sz);
 | 
				
			||||||
 | 
					+    } while (sycl::any_of_group(sg, ((lid ==  3) && (data[3] != pattern)) ||
 | 
				
			||||||
 | 
					+                                    ((lid ==  7) && (data[3] != pattern)) ||
 | 
				
			||||||
 | 
					+                                    ((lid == 11) && (data[3] != pattern)) ||
 | 
				
			||||||
 | 
					+                                    ((lid == 15) && (data[3] != pattern))));
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static inline void shuffle_data(message_t &data)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("mov (M1, 1) %0(1, 7)<1> %0(6, 3)<0;1,0>\n"
 | 
				
			||||||
 | 
					+                         "mov (M1, 1) %0(3, 7)<1> %0(6, 7)<0;1,0>\n"
 | 
				
			||||||
 | 
					+                         "mov (M1, 1) %0(5, 7)<1> %0(7, 3)<0;1,0>\n"
 | 
				
			||||||
 | 
					+                         : "+rw"(reinterpret_cast<typename message_t::vector_t &>(data))
 | 
				
			||||||
 | 
					+                         : );
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static inline void insert_pattern(message_t &data, pattern_t pattern)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("mov (M1, 1) %0(6, 3)<1> %1(0, 0)<0;1,0>\n"
 | 
				
			||||||
 | 
					+                         "mov (M1, 1) %0(6, 7)<1> %1(0, 0)<0;1,0>\n"
 | 
				
			||||||
 | 
					+                         "mov (M1, 1) %0(7, 3)<1> %1(0, 0)<0;1,0>\n"
 | 
				
			||||||
 | 
					+                         "mov (M1, 1) %0(7, 7)<1> %1(0, 0)<0;1,0>\n"
 | 
				
			||||||
 | 
					+                         : "+rw"(reinterpret_cast<typename message_t::vector_t &>(data))
 | 
				
			||||||
 | 
					+                         : "rw"(pattern));
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static inline void restore_data(message_t &data)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    __asm__ __volatile__("mov (M1, 1) %0(6, 3)<1> %0(1, 7)<0;1,0>\n"
 | 
				
			||||||
 | 
					+                         "mov (M1, 1) %0(6, 7)<1> %0(3, 7)<0;1,0>\n"
 | 
				
			||||||
 | 
					+                         "mov (M1, 1) %0(7, 3)<1> %0(5, 7)<0;1,0>\n"
 | 
				
			||||||
 | 
					+                         : "+rw"(reinterpret_cast<typename message_t::vector_t &>(data))
 | 
				
			||||||
 | 
					+                         : );
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+#endif
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static inline void send(char *next, char *src, int lid, int req_workitems,
 | 
				
			||||||
 | 
					+                        const ccl_datatype& dtype, int rank, pattern_t pattern)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    #if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
 | 
				
			||||||
 | 
					+    message_t data;
 | 
				
			||||||
 | 
					+    int sz = sizeof(data);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    LscLoadCached(data, src + lid * sz);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    shuffle_data(data);
 | 
				
			||||||
 | 
					+    insert_pattern(data, pattern);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    LscStoreUnCached(next + lid * sz, data);
 | 
				
			||||||
 | 
					+    #endif
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static inline void recv_reduce_send(char *dst, char *next, char *src, int lid, int req_workitems,
 | 
				
			||||||
 | 
					+                                    const ccl_datatype& dtype, int rank, pattern_t pattern)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    #if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
 | 
				
			||||||
 | 
					+    message_t data;
 | 
				
			||||||
 | 
					+    int sz = sizeof(data);
 | 
				
			||||||
 | 
					+    message_t *dst_buf = (message_t *)dst;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    sync_data(src, data, lid, pattern);
 | 
				
			||||||
 | 
					+    restore_data(data);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    data = sum(dst_buf[lid], data, dtype);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    shuffle_data(data);
 | 
				
			||||||
 | 
					+    insert_pattern(data, pattern);
 | 
				
			||||||
 | 
					+    LscStoreUnCached(next + lid * sz, data);
 | 
				
			||||||
 | 
					+    #endif
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static inline void recv_reduce_copy_send(char *dst, char *next, char *src, int lid, int req_workitems,
 | 
				
			||||||
 | 
					+                                         const ccl_datatype& dtype, int rank, pattern_t pattern)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    #if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
 | 
				
			||||||
 | 
					+    message_t data;
 | 
				
			||||||
 | 
					+    int sz = sizeof(data);
 | 
				
			||||||
 | 
					+    message_t *dst_buf = (message_t *)dst;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    sync_data(src, data, lid, pattern);
 | 
				
			||||||
 | 
					+    restore_data(data);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    data = sum(dst_buf[lid], data, dtype);
 | 
				
			||||||
 | 
					+    if (lid < req_workitems)
 | 
				
			||||||
 | 
					+        LscStoreUnCached(dst + lid * sz, data);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    shuffle_data(data);
 | 
				
			||||||
 | 
					+    insert_pattern(data, pattern);
 | 
				
			||||||
 | 
					+    LscStoreUnCached(next + lid * sz, data);
 | 
				
			||||||
 | 
					+    #endif
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static inline void recv_copy_send(char *dst, char *next, char *src, int lid, int req_workitems,
 | 
				
			||||||
 | 
					+                                  const ccl_datatype& dtype, int rank, pattern_t pattern)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    #if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
 | 
				
			||||||
 | 
					+    message_t data;
 | 
				
			||||||
 | 
					+    int sz = sizeof(data);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    sync_data(src, data, lid, pattern);
 | 
				
			||||||
 | 
					+    LscStoreUnCached(next + lid * sz, data);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    restore_data(data);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    if (lid < req_workitems)
 | 
				
			||||||
 | 
					+        LscStoreUnCached(dst + lid * sz, data);
 | 
				
			||||||
 | 
					+    #endif
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+static inline void recv(char *dst, char *src, int lid, int req_workitems,
 | 
				
			||||||
 | 
					+                        const ccl_datatype& dtype, int rank, pattern_t pattern)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    #if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
 | 
				
			||||||
 | 
					+    message_t data;
 | 
				
			||||||
 | 
					+    int sz = sizeof(data);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    /* copy reduced data from peer */
 | 
				
			||||||
 | 
					+    sync_data(src, data, lid, pattern);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    restore_data(data);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    if (lid < req_workitems)
 | 
				
			||||||
 | 
					+        LscStoreUnCached(dst + lid * sz, data);
 | 
				
			||||||
 | 
					+    #endif
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
 | 
				
			||||||
 | 
					+                               const ccl_datatype &dtype, ccl::reduction reduction, ccl_comm *comm)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    ccl::event ret;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    //std::cout << "enter " << __func__ << ", rank: " << world_rank <<  ", count: " << count << std::endl;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    size_t dt_sz = dtype.size();
 | 
				
			||||||
 | 
					+    char *recv_buf = static_cast<char *>(dst);
 | 
				
			||||||
 | 
					+    char *send_buf = static_cast<char *>(const_cast<void*>(src));
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    /*
 | 
				
			||||||
 | 
					+     * Intel(R) Arc(TM) A770 Graphics:
 | 
				
			||||||
 | 
					+     *   Number Of Slices:                       1
 | 
				
			||||||
 | 
					+     *   Number Of Subslices Per Slice:          32
 | 
				
			||||||
 | 
					+     *   Number Of EU Per Subslice:              16
 | 
				
			||||||
 | 
					+     *   Number Of Threads Per EU:               8
 | 
				
			||||||
 | 
					+     *   Total EU Count:                         512
 | 
				
			||||||
 | 
					+     *   Physical EU SIMD Width:                 8
 | 
				
			||||||
 | 
					+     */
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    /* 64-byte load/store granularity to HBM, Maximum 128-byte payload can be used by EU store */
 | 
				
			||||||
 | 
					+    /* Arc770: Subgroup Sizes Supported: 8;16;32, while 8 threads per EU */
 | 
				
			||||||
 | 
					+    size_t sg_sz = SG_SZ;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    size_t l_sz = 1 * sg_sz;
 | 
				
			||||||
 | 
					+    size_t g_sz = 512 * l_sz;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    /* To avoid pattern not changed when "iters" is 1 */
 | 
				
			||||||
 | 
					+    pattern_t pattern_prefix = ++pattern_counter << 16;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    q.submit([&](auto& h) {
 | 
				
			||||||
 | 
					+        using namespace sycl::ext::intel::experimental::esimd;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+        int local_world_rank = world_rank;
 | 
				
			||||||
 | 
					+        int local_world_size = world_size;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+        int next_rank = (local_world_rank + 1) % local_world_size;
 | 
				
			||||||
 | 
					+        char *local_host_buf = (char *)host_bufs[local_world_rank];
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+        char *local_peer_bufs[DG2_NUM];
 | 
				
			||||||
 | 
					+        for (int i = 0; i < world_size; i++)
 | 
				
			||||||
 | 
					+            local_peer_bufs[i] = (char *)peer_bufs[i];
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+        /*
 | 
				
			||||||
 | 
					+         * In a single subgroup:
 | 
				
			||||||
 | 
					+         *   a> 1 dedicated work-item to manage a LS_SZ-byte pattern.
 | 
				
			||||||
 | 
					+         *   b> other work-items to process data, and each of them handle a LS_SZ-byte data.
 | 
				
			||||||
 | 
					+         */
 | 
				
			||||||
 | 
					+        auto default_subgroup_capacity = sg_sz * LS_SZ;  /* bytes: data and pattern  processed by 1 subgroup */
 | 
				
			||||||
 | 
					+        auto default_workgroup_capacity = l_sz * LS_SZ;  /* bytes: data and patterns processed by 1 workgroup */
 | 
				
			||||||
 | 
					+        //auto default_total_capacity = g_sz * LS_SZ;      /* bytes: data and patterns processed by all workgroups in 1 iteration */
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+        /* In a single workgroup, the available work-items to process data, excluding work-items for patterns */
 | 
				
			||||||
 | 
					+        auto workgroup_available_items = l_sz - (l_sz / sg_sz);
 | 
				
			||||||
 | 
					+        auto total_available_items = (g_sz / l_sz) * workgroup_available_items;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+        auto subgroup_capacity = LS_SZ * (sg_sz - 1);                  /* bytes: data processed by 1 subgroup */
 | 
				
			||||||
 | 
					+        auto workgroup_capacity = LS_SZ * workgroup_available_items;   /* bytes: data processed by 1 workgroup */
 | 
				
			||||||
 | 
					+        auto total_capacity = (g_sz / l_sz) * workgroup_capacity;      /* bytes: data processed by all workgroups in 1 iteration */
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+        /* div up */
 | 
				
			||||||
 | 
					+        int iters = (count * dt_sz + (local_world_size * total_available_items * LS_SZ - 1)) / (local_world_size * total_available_items * LS_SZ);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+        //sycl::ext::oneapi::experimental::printf("------> rank: %d, group num: %ld, loop count: %zu\n", local_world_rank, g_sz / l_sz, iters);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+        h.parallel_for(sycl::nd_range<1>(g_sz, l_sz), [=] (sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(SG_SZ)]] {
 | 
				
			||||||
 | 
					+            int idx = 0;
 | 
				
			||||||
 | 
					+            size_t offset = 0;
 | 
				
			||||||
 | 
					+            size_t offset_with_pattern = 0;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+            auto group_id = item.get_group_linear_id();
 | 
				
			||||||
 | 
					+            auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
 | 
				
			||||||
 | 
					+            auto sg_id = sg.get_group_id()[0];
 | 
				
			||||||
 | 
					+            auto sg_lid = sg.get_local_id()[0];
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+            for (int i = 0; i < iters; i++) {
 | 
				
			||||||
 | 
					+                pattern_t pattern = pattern_prefix + i;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                auto base = local_world_size * (i        * total_capacity  +
 | 
				
			||||||
 | 
					+                                                group_id * workgroup_capacity  +
 | 
				
			||||||
 | 
					+                                                sg_id    * subgroup_capacity);
 | 
				
			||||||
 | 
					+                auto base_with_pattern = local_world_size * (/* i        * default_total_capacity  + */
 | 
				
			||||||
 | 
					+                                                             group_id * default_workgroup_capacity  +
 | 
				
			||||||
 | 
					+                                                             sg_id    * default_subgroup_capacity);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                auto finished = i * total_capacity * local_world_size;   /* bytes */
 | 
				
			||||||
 | 
					+                auto unreduced = count * dt_sz - finished;               /* bytes */
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                auto req_workitems = sg_sz - 1;                /* required work-items exclude 1 work-item for pattern */
 | 
				
			||||||
 | 
					+                auto chunk_sz = req_workitems * LS_SZ;         /* LS_SZ bytes per work-item */
 | 
				
			||||||
 | 
					+                auto chunk_with_pattern = sg_sz * LS_SZ;       /* aligned to 256B */
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                /* items will be assigned to each rank */
 | 
				
			||||||
 | 
					+                auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
 | 
				
			||||||
 | 
					+                auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
 | 
				
			||||||
 | 
					+                auto req_subgroups = 0;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                if (req_workgroups >= g_sz/l_sz) {
 | 
				
			||||||
 | 
					+                    req_workgroups = g_sz/l_sz;
 | 
				
			||||||
 | 
					+                } else {
 | 
				
			||||||
 | 
					+                    if (group_id == (req_workgroups - 1)) {
 | 
				
			||||||
 | 
					+                        req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
 | 
				
			||||||
 | 
					+                        /* Note:  req_subgroups % (l_sz/sg_sz) might be 0 */
 | 
				
			||||||
 | 
					+                        if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
 | 
				
			||||||
 | 
					+                            if ((per_rank_items % (sg_sz - 1)) != 0) {
 | 
				
			||||||
 | 
					+                                /* FIXME: */
 | 
				
			||||||
 | 
					+                                req_workitems = per_rank_items % (sg_sz - 1);
 | 
				
			||||||
 | 
					+                                chunk_sz = req_workitems * LS_SZ;    /* LS_SZ bytes per work-item */
 | 
				
			||||||
 | 
					+                            }
 | 
				
			||||||
 | 
					+                        }
 | 
				
			||||||
 | 
					+                    }
 | 
				
			||||||
 | 
					+                }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                if (group_id < req_workgroups) {
 | 
				
			||||||
 | 
					+                    // step 1: push data to next GPU
 | 
				
			||||||
 | 
					+                    {
 | 
				
			||||||
 | 
					+                        offset = base + local_world_rank * chunk_sz;
 | 
				
			||||||
 | 
					+                        offset_with_pattern = base_with_pattern + local_world_rank * chunk_with_pattern;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        char *next = local_peer_bufs[next_rank];
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        send(next + offset_with_pattern, send_buf + offset, sg_lid, req_workitems, dtype, local_world_rank, pattern);
 | 
				
			||||||
 | 
					+                    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                    // step 2: reduce and copy to next GPU
 | 
				
			||||||
 | 
					+                    for (int j = 2; j < local_world_size; j++) {
 | 
				
			||||||
 | 
					+                        idx = (local_world_rank + local_world_size + 1 - j) % local_world_size;
 | 
				
			||||||
 | 
					+                        offset = base + idx * chunk_sz;
 | 
				
			||||||
 | 
					+                        offset_with_pattern = base_with_pattern + idx * chunk_with_pattern;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        char *src = local_host_buf;
 | 
				
			||||||
 | 
					+                        char *next = local_peer_bufs[next_rank];
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        recv_reduce_send(recv_buf + offset, next + offset_with_pattern, src + offset_with_pattern,
 | 
				
			||||||
 | 
					+                                         sg_lid, req_workitems, dtype, local_world_rank, pattern);
 | 
				
			||||||
 | 
					+                    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                    // step 3: reduce this buffer and data, which will produce the final
 | 
				
			||||||
 | 
					+                    // result that we store in this data and push to the next GPU
 | 
				
			||||||
 | 
					+                    {
 | 
				
			||||||
 | 
					+                        idx = (local_world_rank + 1) % local_world_size;
 | 
				
			||||||
 | 
					+                        offset = base + idx * chunk_sz;
 | 
				
			||||||
 | 
					+                        offset_with_pattern = base_with_pattern + idx * chunk_with_pattern;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        char *src = local_host_buf;
 | 
				
			||||||
 | 
					+                        char *next = local_peer_bufs[next_rank];
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        recv_reduce_copy_send(recv_buf + offset, next + GATHER_BUF_OFFSET + offset_with_pattern, src + offset_with_pattern,
 | 
				
			||||||
 | 
					+                                              sg_lid, req_workitems, dtype, local_world_rank, pattern);
 | 
				
			||||||
 | 
					+                    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                    // step 4: copy to next GPU
 | 
				
			||||||
 | 
					+                    for (int j = 1; j < local_world_size - 1; ++j) {
 | 
				
			||||||
 | 
					+                        idx = (local_world_rank + local_world_size + 1 - j) % local_world_size;
 | 
				
			||||||
 | 
					+                        offset = base + idx * chunk_sz;
 | 
				
			||||||
 | 
					+                        offset_with_pattern = GATHER_BUF_OFFSET + base_with_pattern + idx * chunk_with_pattern;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        char *src = local_host_buf;
 | 
				
			||||||
 | 
					+                        char *next = local_peer_bufs[next_rank];
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        recv_copy_send(recv_buf + offset, next + offset_with_pattern, src + offset_with_pattern,
 | 
				
			||||||
 | 
					+                                       sg_lid, req_workitems, dtype, local_world_rank, pattern);
 | 
				
			||||||
 | 
					+                    }
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                    // step 5: Make final copy from buffer to dest
 | 
				
			||||||
 | 
					+                    {
 | 
				
			||||||
 | 
					+                        idx = (local_world_rank + 2) % local_world_size;
 | 
				
			||||||
 | 
					+                        offset = base + idx * chunk_sz;
 | 
				
			||||||
 | 
					+                        offset_with_pattern = GATHER_BUF_OFFSET + base_with_pattern + idx * chunk_with_pattern;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        char *src = local_host_buf;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        recv(recv_buf + offset, src + offset_with_pattern, sg_lid, req_workitems, dtype, local_world_rank, pattern);
 | 
				
			||||||
 | 
					+                    }
 | 
				
			||||||
 | 
					+                }
 | 
				
			||||||
 | 
					+            }
 | 
				
			||||||
 | 
					+        });
 | 
				
			||||||
 | 
					+    });
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    return ret;
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+ccl::event dg2_allreduce(const void *src, void *dst, size_t count,
 | 
				
			||||||
 | 
					+                         const ccl_datatype &dtype, ccl::reduction reduction, ccl_comm *comm)
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    ccl::event ret;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    dg2_ll256_allreduce(src, dst, count, dtype, reduction, comm);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    return ret;
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+void dg2_clear()
 | 
				
			||||||
 | 
					+{
 | 
				
			||||||
 | 
					+    for (int i = 0; i < world_size; i++) {
 | 
				
			||||||
 | 
					+        if (i == world_rank)
 | 
				
			||||||
 | 
					+            continue;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+        auto host_buf = host_bufs[world_rank];
 | 
				
			||||||
 | 
					+        if (host_buf)
 | 
				
			||||||
 | 
					+            sycl::free(host_buf, q);
 | 
				
			||||||
 | 
					+    }
 | 
				
			||||||
 | 
					+}
 | 
				
			||||||
 | 
					diff --git a/src/dg2/dg2_allreduce.hpp b/src/dg2/dg2_allreduce.hpp
 | 
				
			||||||
 | 
					new file mode 100644
 | 
				
			||||||
 | 
					index 0000000..0506445
 | 
				
			||||||
 | 
					--- /dev/null
 | 
				
			||||||
 | 
					+++ b/src/dg2/dg2_allreduce.hpp
 | 
				
			||||||
 | 
					@@ -0,0 +1,13 @@
 | 
				
			||||||
 | 
					+#include "oneapi/ccl.hpp"
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#define DG2_NUM (4)
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+#define LL256_BUF_SIZE (32 * 1024 * 1024)
 | 
				
			||||||
 | 
					+#define GATHER_BUF_OFFSET (LL256_BUF_SIZE / 2)
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+void dg2_init(ccl_coll_param ¶m);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+ccl::event dg2_allreduce(const void *src, void *dst, size_t count,
 | 
				
			||||||
 | 
					+                         const ccl_datatype& dtype, ccl::reduction reduction, ccl_comm *comm);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+void dg2_clear();
 | 
				
			||||||
 | 
					-- 
 | 
				
			||||||
 | 
					2.34.1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					From 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa Mon Sep 17 00:00:00 2001
 | 
				
			||||||
 | 
					From: Huajun Li <huajun.li@.com>
 | 
				
			||||||
 | 
					Date: Fri, 7 Mar 2025 15:15:35 +0800
 | 
				
			||||||
 | 
					Subject: [PATCH 2/3] optimize req_workgroup calculate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					 src/dg2/dg2_allreduce.cpp | 25 ++-----------------------
 | 
				
			||||||
 | 
					 1 file changed, 2 insertions(+), 23 deletions(-)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
 | 
				
			||||||
 | 
					index 15ace74..83270ae 100644
 | 
				
			||||||
 | 
					--- a/src/dg2/dg2_allreduce.cpp
 | 
				
			||||||
 | 
					+++ b/src/dg2/dg2_allreduce.cpp
 | 
				
			||||||
 | 
					@@ -527,30 +527,9 @@ ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
 | 
				
			||||||
 | 
					                 auto chunk_sz = req_workitems * LS_SZ;         /* LS_SZ bytes per work-item */
 | 
				
			||||||
 | 
					                 auto chunk_with_pattern = sg_sz * LS_SZ;       /* aligned to 256B */
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					-                /* items will be assigned to each rank */
 | 
				
			||||||
 | 
					-                auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
 | 
				
			||||||
 | 
					-                auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
 | 
				
			||||||
 | 
					-                auto req_subgroups = 0;
 | 
				
			||||||
 | 
					-
 | 
				
			||||||
 | 
					-                if (req_workgroups >= g_sz/l_sz) {
 | 
				
			||||||
 | 
					-                    req_workgroups = g_sz/l_sz;
 | 
				
			||||||
 | 
					-                } else {
 | 
				
			||||||
 | 
					-                    if (group_id == (req_workgroups - 1)) {
 | 
				
			||||||
 | 
					-                        req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
 | 
				
			||||||
 | 
					-
 | 
				
			||||||
 | 
					-                        /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
 | 
				
			||||||
 | 
					-                        /* Note:  req_subgroups % (l_sz/sg_sz) might be 0 */
 | 
				
			||||||
 | 
					-                        if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
 | 
				
			||||||
 | 
					-                            if ((per_rank_items % (sg_sz - 1)) != 0) {
 | 
				
			||||||
 | 
					-                                /* FIXME: */
 | 
				
			||||||
 | 
					-                                req_workitems = per_rank_items % (sg_sz - 1);
 | 
				
			||||||
 | 
					-                                chunk_sz = req_workitems * LS_SZ;    /* LS_SZ bytes per work-item */
 | 
				
			||||||
 | 
					-                            }
 | 
				
			||||||
 | 
					-                        }
 | 
				
			||||||
 | 
					-                    }
 | 
				
			||||||
 | 
					-                }
 | 
				
			||||||
 | 
					+		auto work_left = unreduced - sg_id * local_world_size * chunk_sz;
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					-                if (group_id < req_workgroups) {
 | 
				
			||||||
 | 
					+                if (work_left > 0) {
 | 
				
			||||||
 | 
					                     // step 1: push data to next GPU
 | 
				
			||||||
 | 
					                     {
 | 
				
			||||||
 | 
					                         offset = base + local_world_rank * chunk_sz;
 | 
				
			||||||
 | 
					-- 
 | 
				
			||||||
 | 
					2.34.1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					From 1c58cc9ede5ca38138a270f9e5ff59bca74f51d4 Mon Sep 17 00:00:00 2001
 | 
				
			||||||
 | 
					From: Huajun Li <huajun.li@.com>
 | 
				
			||||||
 | 
					Date: Wed, 12 Mar 2025 13:29:27 +0800
 | 
				
			||||||
 | 
					Subject: [PATCH 3/3] Revert "optimize req_workgroup calculate" for hang issue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This reverts commit 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa.
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					 src/dg2/dg2_allreduce.cpp | 25 +++++++++++++++++++++++--
 | 
				
			||||||
 | 
					 1 file changed, 23 insertions(+), 2 deletions(-)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
 | 
				
			||||||
 | 
					index 83270ae..15ace74 100644
 | 
				
			||||||
 | 
					--- a/src/dg2/dg2_allreduce.cpp
 | 
				
			||||||
 | 
					+++ b/src/dg2/dg2_allreduce.cpp
 | 
				
			||||||
 | 
					@@ -527,9 +527,30 @@ ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
 | 
				
			||||||
 | 
					                 auto chunk_sz = req_workitems * LS_SZ;         /* LS_SZ bytes per work-item */
 | 
				
			||||||
 | 
					                 auto chunk_with_pattern = sg_sz * LS_SZ;       /* aligned to 256B */
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					-		auto work_left = unreduced - sg_id * local_world_size * chunk_sz;
 | 
				
			||||||
 | 
					+                /* items will be assigned to each rank */
 | 
				
			||||||
 | 
					+                auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
 | 
				
			||||||
 | 
					+                auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
 | 
				
			||||||
 | 
					+                auto req_subgroups = 0;
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                if (req_workgroups >= g_sz/l_sz) {
 | 
				
			||||||
 | 
					+                    req_workgroups = g_sz/l_sz;
 | 
				
			||||||
 | 
					+                } else {
 | 
				
			||||||
 | 
					+                    if (group_id == (req_workgroups - 1)) {
 | 
				
			||||||
 | 
					+                        req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+                        /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
 | 
				
			||||||
 | 
					+                        /* Note:  req_subgroups % (l_sz/sg_sz) might be 0 */
 | 
				
			||||||
 | 
					+                        if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
 | 
				
			||||||
 | 
					+                            if ((per_rank_items % (sg_sz - 1)) != 0) {
 | 
				
			||||||
 | 
					+                                /* FIXME: */
 | 
				
			||||||
 | 
					+                                req_workitems = per_rank_items % (sg_sz - 1);
 | 
				
			||||||
 | 
					+                                chunk_sz = req_workitems * LS_SZ;    /* LS_SZ bytes per work-item */
 | 
				
			||||||
 | 
					+                            }
 | 
				
			||||||
 | 
					+                        }
 | 
				
			||||||
 | 
					+                    }
 | 
				
			||||||
 | 
					+                }
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					-                if (work_left > 0) {
 | 
				
			||||||
 | 
					+                if (group_id < req_workgroups) {
 | 
				
			||||||
 | 
					                     // step 1: push data to next GPU
 | 
				
			||||||
 | 
					                     {
 | 
				
			||||||
 | 
					                         offset = base + local_world_rank * chunk_sz;
 | 
				
			||||||
 | 
					-- 
 | 
				
			||||||
 | 
					2.34.1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,6 +16,7 @@ COPY ./ccl_torch.patch /tmp/
 | 
				
			||||||
COPY ./vllm_online_benchmark.py ./vllm_offline_inference.py ./vllm_offline_inference_vision_language.py \
 | 
					COPY ./vllm_online_benchmark.py ./vllm_offline_inference.py ./vllm_offline_inference_vision_language.py \
 | 
				
			||||||
     ./payload-1024.lua ./start-vllm-service.sh ./benchmark_vllm_throughput.py ./benchmark_vllm_latency.py \
 | 
					     ./payload-1024.lua ./start-vllm-service.sh ./benchmark_vllm_throughput.py ./benchmark_vllm_latency.py \
 | 
				
			||||||
     ./start-pp_serving-service.sh /llm/
 | 
					     ./start-pp_serving-service.sh /llm/
 | 
				
			||||||
 | 
					COPY ./1ccl_for_multi_arc.patch /build/
 | 
				
			||||||
     
 | 
					     
 | 
				
			||||||
RUN set -eux && \
 | 
					RUN set -eux && \
 | 
				
			||||||
    #
 | 
					    #
 | 
				
			||||||
| 
						 | 
					@ -55,7 +56,7 @@ RUN set -eux && \
 | 
				
			||||||
    pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu && \
 | 
					    pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu && \
 | 
				
			||||||
    #
 | 
					    #
 | 
				
			||||||
    # Build torch-ccl
 | 
					    # Build torch-ccl
 | 
				
			||||||
    mkdir /build && \
 | 
					    mkdir -p /build && \
 | 
				
			||||||
    cd /build && \
 | 
					    cd /build && \
 | 
				
			||||||
    git clone https://github.com/intel/torch-ccl.git && \
 | 
					    git clone https://github.com/intel/torch-ccl.git && \
 | 
				
			||||||
    cd torch-ccl && \
 | 
					    cd torch-ccl && \
 | 
				
			||||||
| 
						 | 
					@ -64,8 +65,21 @@ RUN set -eux && \
 | 
				
			||||||
    git submodule update --init --recursive && \
 | 
					    git submodule update --init --recursive && \
 | 
				
			||||||
    # This patch will enable build torch-ccl with pytorch 2.6 environment
 | 
					    # This patch will enable build torch-ccl with pytorch 2.6 environment
 | 
				
			||||||
    git apply /tmp/ccl_torch.patch && \
 | 
					    git apply /tmp/ccl_torch.patch && \
 | 
				
			||||||
    USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py bdist_wheel
 | 
					    USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py bdist_wheel && \
 | 
				
			||||||
    # File path: /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl
 | 
					    # File path: /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl
 | 
				
			||||||
 | 
					    # Build oneCCL
 | 
				
			||||||
 | 
					    pip install ninja && \
 | 
				
			||||||
 | 
					    cd /build/ && \
 | 
				
			||||||
 | 
					    git clone https://github.com/analytics-zoo/oneCCL.git && \
 | 
				
			||||||
 | 
					    cd oneCCL && \
 | 
				
			||||||
 | 
					    git checkout 3afa1bb7936f57683a2503c34b29c0daca6a9ccb && \
 | 
				
			||||||
 | 
					    git apply /build/1ccl_for_multi_arc.patch && \
 | 
				
			||||||
 | 
					    mkdir build && \
 | 
				
			||||||
 | 
					    cd build && \
 | 
				
			||||||
 | 
					    cmake .. -GNinja -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DCMAKE_CXX_FLAGS="-fsycl" -DCOMPUTE_BACKEND=dpcpp  -DCMAKE_BUILD_TYPE=MinSizeRel && \
 | 
				
			||||||
 | 
					    # File path: /build/oneCCL/build/src/libccl.so.1.0
 | 
				
			||||||
 | 
					    ninja
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Second stage: Final runtime image
 | 
					# Second stage: Final runtime image
 | 
				
			||||||
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04
 | 
					FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04
 | 
				
			||||||
| 
						 | 
					@ -73,6 +87,11 @@ FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04
 | 
				
			||||||
# Copy the built torch-ccl package from the build stage
 | 
					# Copy the built torch-ccl package from the build stage
 | 
				
			||||||
COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl /opt/
 | 
					COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl /opt/
 | 
				
			||||||
COPY --from=build /llm/ /llm/
 | 
					COPY --from=build /llm/ /llm/
 | 
				
			||||||
 | 
					COPY --from=build /build/oneCCL/build/src/libccl.so.1.0 /opt/intel/1ccl-wks/lib/
 | 
				
			||||||
 | 
					COPY --from=build /build/oneCCL/build/src/libccl.so.1 /opt/intel/1ccl-wks/lib/
 | 
				
			||||||
 | 
					COPY --from=build /build/oneCCL/build/src/libccl.so /opt/intel/1ccl-wks/lib/
 | 
				
			||||||
 | 
					COPY ./vllm_for_multi_arc.patch /llm/
 | 
				
			||||||
 | 
					COPY ./setvars.sh /opt/intel/1ccl-wks/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ARG http_proxy
 | 
					ARG http_proxy
 | 
				
			||||||
ARG https_proxy
 | 
					ARG https_proxy
 | 
				
			||||||
| 
						 | 
					@ -128,10 +147,6 @@ RUN set -eux && \
 | 
				
			||||||
    # Install torch-ccl
 | 
					    # Install torch-ccl
 | 
				
			||||||
    pip install /opt/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl && \
 | 
					    pip install /opt/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl && \
 | 
				
			||||||
    #
 | 
					    #
 | 
				
			||||||
    # Install Internal oneccl
 | 
					 | 
				
			||||||
    cd /opt && \
 | 
					 | 
				
			||||||
    wget https://sourceforge.net/projects/oneccl-wks/files/2025.0.0.6.6-release/oneccl_wks_installer_2025.0.0.6.6.sh && \
 | 
					 | 
				
			||||||
    bash oneccl_wks_installer_2025.0.0.6.6.sh && \
 | 
					 | 
				
			||||||
    apt-get update && \
 | 
					    apt-get update && \
 | 
				
			||||||
    apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \
 | 
					    apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \
 | 
				
			||||||
    # 
 | 
					    # 
 | 
				
			||||||
| 
						 | 
					@ -154,13 +169,16 @@ RUN set -eux && \
 | 
				
			||||||
    rm -rf /tmp/neo && \
 | 
					    rm -rf /tmp/neo && \
 | 
				
			||||||
    #
 | 
					    #
 | 
				
			||||||
    # Install vllm
 | 
					    # Install vllm
 | 
				
			||||||
    git clone -b 0.6.6 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
 | 
					    git clone -b v0.6.6.post1 https://github.com/vllm-project/vllm /llm/vllm && \
 | 
				
			||||||
    cd /llm/vllm && \
 | 
					    cd /llm/vllm && \
 | 
				
			||||||
 | 
					    git apply /llm/vllm_for_multi_arc.patch && \
 | 
				
			||||||
    pip install setuptools-scm && \
 | 
					    pip install setuptools-scm && \
 | 
				
			||||||
    pip install --upgrade cmake && \
 | 
					    pip install --upgrade cmake && \
 | 
				
			||||||
    VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
 | 
					    VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
 | 
				
			||||||
 | 
					    rm -rf /llm/vllm_for_multi_arc.patch && \
 | 
				
			||||||
    pip install mpi4py fastapi uvicorn openai && \
 | 
					    pip install mpi4py fastapi uvicorn openai && \
 | 
				
			||||||
    pip install ray
 | 
					    pip install ray
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
WORKDIR /llm/
 | 
					WORKDIR /llm/
 | 
				
			||||||
ENTRYPOINT ["bash", "/llm/start-vllm-service.sh"]
 | 
					ENTRYPOINT ["bash", "/llm/start-vllm-service.sh"]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										4
									
								
								docker/llm/serving/xpu/docker/setvars.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								docker/llm/serving/xpu/docker/setvars.sh
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,4 @@
 | 
				
			||||||
 | 
					#!/bin/sh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export CCL_DG2_ALLREDUCE=1
 | 
				
			||||||
 | 
					export LD_LIBRARY_PATH=/opt/intel/1ccl-wks/lib:$LD_LIBRARY_PATH
 | 
				
			||||||
							
								
								
									
										40781
									
								
								docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40781
									
								
								docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
		Loading…
	
		Reference in a new issue