Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

[XPU] update xhpc to support VL model pretraining and inference #75870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
yongqiangma merged 4 commits into PaddlePaddle:release/3.2 from cqulilujia:xhpc
Oct 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[XPU] support index_elementwise_get kernel (#75486)
  • Loading branch information
cqulilujia committed Oct 15, 2025
commit 50a67078e97a88e222767d96241e171eb2bc8efa
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ add_compile_definitions(XPUAPI_NOT_INCLUDE_DEPRECATED)
if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "dev/20250922")
endif()
set(XPU_XCCL_BASE_VERSION "3.0.3.1") # For XRE5
set(XPU_XCCL_BASE_VERSION "3.0.3.3") # For XRE5
if(NOT DEFINED XPU_XFT_BASE_VERSION)
set(XPU_XFT_BASE_VERSION "20250507/xpu3")
endif()
Expand Down
33 changes: 33 additions & 0 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,39 @@ XPUOpMap& get_kl3_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"index_elementwise_get",
XPUKernelSet({phi::DataType::BOOL,
phi::DataType::INT32,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::INT64,
phi::DataType::FLOAT32,
phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT64})},
{"index_elementwise_put",
XPUKernelSet({phi::DataType::BOOL,
phi::DataType::INT32,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::INT64,
phi::DataType::FLOAT32,
phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT64})},
{"index_elementwise_put_with_tensor",
XPUKernelSet({phi::DataType::BOOL,
phi::DataType::INT32,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::INT64,
phi::DataType::FLOAT32,
phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT64})},
{"index_put",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/kernels/cpu/index_elementwise_get_kernel.cc
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ namespace phi {
template <typename T, typename IndexT = int>
void CPUIndexElementwiseGetKernel(const phi::CPUContext& dev_ctx,
const DenseTensor& input,
const std::vector<const DenseTensor*> index,
const std::vector<const DenseTensor*>& index,
const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& input_strides,
const std::vector<int64_t>& index_dims,
const std::vector<int64_t>& index_stride,
const std::vector<int64_t>& index_strides,
const int64_t slice_offset,
DenseTensor* output) {
int64_t numel = 0;
Expand All @@ -42,7 +42,7 @@ void CPUIndexElementwiseGetKernel(const phi::CPUContext& dev_ctx,
auto strides = std::array<int64_t, DDim::kMaxRank>{};
for (int64_t i = 0; i < num_indices; i++) {
sizes[i] = index_dims[i];
strides[i] = index_stride[i];
strides[i] = index_strides[i];
}
std::array<int64_t*, 3> strides_array;
std::vector<int64_t> desired_shape;
Expand Down Expand Up @@ -97,7 +97,7 @@ void IndexElementwiseGetKernel(const Context& dev_ctx,
const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& input_strides,
const std::vector<int64_t>& index_dims,
const std::vector<int64_t>& index_stride,
const std::vector<int64_t>& index_strides,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
Expand All @@ -124,7 +124,7 @@ void IndexElementwiseGetKernel(const Context& dev_ctx,
input_dims,
input_strides,
index_dims,
index_stride,
index_strides,
slice_offset,
out);
}
Expand Down
7 changes: 5 additions & 2 deletions paddle/phi/kernels/cpu/index_elementwise_put_grad_kernel.cc
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,11 @@ void CPUIndexElementwisePutGradKernel(
auto offset_calc =
funcs::CPUmake_offset_calculator_put<3>(desired_shape, strides_array);
const int64_t N = numel;
PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
"N >= 0 && N <= std::numeric_limits<int32_t>::max()");
PADDLE_ENFORCE_EQ(true,
(N >= 0 && N <= std::numeric_limits<int32_t>::max()),
common::errors::PreconditionNotMet(
"the value of N should be in [0, "
"std::numeric_limits<int32_t>::max()]"));
using dtype = funcs::OpaqueType<sizeof(T)>;
if (!value_grad) {
char* out_ptr = reinterpret_cast<char*>(x_grad->data<T>());
Expand Down
20 changes: 13 additions & 7 deletions paddle/phi/kernels/cpu/index_elementwise_put_kernel.cc
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ void CPUIndexElementwisePutWithTensorKernel(
auto offset_calc =
funcs::CPUmake_offset_calculator_put<3>(desired_shape, strides_array);
const int64_t N = numel;
PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
"N >= 0 && N <= std::numeric_limits<int32_t>::max()");
PADDLE_ENFORCE_EQ(true,
(N >= 0 && N <= std::numeric_limits<int32_t>::max()),
common::errors::PreconditionNotMet(
"the value of N should be in [0, "
"std::numeric_limits<int32_t>::max()]"));
using dtype = funcs::OpaqueType<sizeof(T)>;
const char* in_ptr = reinterpret_cast<const char*>(value.data<T>());
char* out_ptr = reinterpret_cast<char*>(output_);
Expand Down Expand Up @@ -150,14 +153,17 @@ void CPUIndexElementwisePutKernel(const phi::CPUContext& dev_ctx,
auto offset_calc =
funcs::CPUmake_offset_calculator_put<3>(desired_shape, strides_array);
const int64_t N = numel;
PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
"N >= 0 && N <= std::numeric_limits<int32_t>::max()");
char* out_ptr = reinterpret_cast<char*>(output_);
PADDLE_ENFORCE_EQ(true,
(N >= 0 && N <= std::numeric_limits<int32_t>::max()),
common::errors::PreconditionNotMet(
"the value of N should be in [0, "
"std::numeric_limits<int32_t>::max()]"));
char* out_ptr = reinterpret_cast<char*>(output_) + slice_offset;
if (index.size() == 1 && index[0]->dtype() == phi::DataType::BOOL) {
const bool* mask_data = index[0]->data<bool>();
for (int64_t idx = 0; idx < N; idx++) {
const auto offsets = offset_calc.cpu_get(idx);
char* const out_data = out_ptr + offsets[0] + slice_offset;
char* const out_data = out_ptr + offsets[0];
if (mask_data[idx]) {
*reinterpret_cast<T*>(out_data) = value_T;
}
Expand All @@ -166,7 +172,7 @@ void CPUIndexElementwisePutKernel(const phi::CPUContext& dev_ctx,
auto index_ptrs = funcs::GetIndexDataPtrs<IndexT>(index);
for (int64_t idx = 0; idx < N; idx++) {
const auto offsets = offset_calc.cpu_get(idx);
char* const out_data = out_ptr + offsets[0] + slice_offset;
char* const out_data = out_ptr + offsets[0];
int64_t offset = 0;
for (int64_t i = 0; i < num_indices; i++) {
int64_t index = *reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]);
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/index_elementwise_utils.h
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct alignas(N) OpaqueType {

template <typename IndexT>
std::array<char*, DDim::kMaxRank> GetIndexDataPtrs(
const std::vector<const DenseTensor*> index) {
const std::vector<const DenseTensor*>& index) {
std::array<char*, DDim::kMaxRank> index_ptrs{};

PADDLE_ENFORCE_LE(index.size(),
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace phi {
template <typename T, typename IndexT = int>
void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx,
const DenseTensor& input,
const std::vector<const DenseTensor*> index,
const std::vector<const DenseTensor*>& index,
const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& input_strides,
const std::vector<int64_t>& index_dims,
Expand Down
7 changes: 5 additions & 2 deletions paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ void GPUIndexElementwisePutGradKernel(
auto offset_calc =
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);
const int64_t N = numel;
PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
"N >= 0 && N <= std::numeric_limits<int32_t>::max()");
PADDLE_ENFORCE_EQ(true,
(N >= 0 && N <= std::numeric_limits<int32_t>::max()),
common::errors::PreconditionNotMet(
"the value of N should be in [0, "
"std::numeric_limits<int32_t>::max()]"));
constexpr int nt = 128;
constexpr int vt = 4;
const dim3 block(nt);
Expand Down
14 changes: 10 additions & 4 deletions paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ void GPUIndexElementwisePutKernel(const phi::GPUContext& dev_ctx,
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);

const int64_t N = numel;
PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
"N >= 0 && N <= std::numeric_limits<int32_t>::max()");
PADDLE_ENFORCE_EQ(true,
(N >= 0 && N <= std::numeric_limits<int32_t>::max()),
common::errors::PreconditionNotMet(
"the value of N should be in [0, "
"std::numeric_limits<int32_t>::max()]"));
constexpr int nt = 128;
constexpr int vt = 4;
const dim3 block(nt);
Expand Down Expand Up @@ -159,8 +162,11 @@ void GPUIndexElementwisePutWithTensorKernel(
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);

const int64_t N = numel;
PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
"N >= 0 && N <= std::numeric_limits<int32_t>::max()");
PADDLE_ENFORCE_EQ(true,
(N >= 0 && N <= std::numeric_limits<int32_t>::max()),
common::errors::PreconditionNotMet(
"the value of N should be in [0, "
"std::numeric_limits<int32_t>::max()]"));
constexpr int nt = 128;
constexpr int vt = 4;
const dim3 block(nt);
Expand Down
7 changes: 5 additions & 2 deletions paddle/phi/kernels/stride/indexing_kernel.cu
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,11 @@ void LaunchIndexPutKernel_V2(const Context& dev_ctx,
funcs::OffsetCalculator offset_calc = funcs::make_offset_calculator<3>(iter);

const int64_t N = iter.numel();
PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
"N >= 0 && N <= std::numeric_limits<int32_t>::max()");
PADDLE_ENFORCE_EQ(true,
(N >= 0 && N <= std::numeric_limits<int32_t>::max()),
common::errors::PreconditionNotMet(
"the value of N should be in [0, "
"std::numeric_limits<int32_t>::max()]"));
constexpr int nt = 128;
constexpr int vt = 4;
const dim3 block(nt);
Expand Down
171 changes: 171 additions & 0 deletions paddle/phi/kernels/xpu/index_elementwise_get_kernel.cc
View file Open in desktop
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/index_elementwise_get_kernel.h"

#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/index_elementwise.h"
#include "paddle/phi/kernels/funcs/stride_utils.h"

namespace phi {
// Gather kernel body for XPU: reads elements of `input` addressed by the
// `index` tensors and writes them densely into `output`, delegating the
// actual copy to the XPU runtime primitive `xpu::index_elementwise_tensor`.
//
// Parameters:
//   input         - source tensor to gather from.
//   index         - one index tensor per indexed dimension, dtype IndexT.
//   input_dims    - dims describing the input view to index into.
//   input_strides - strides matching input_dims.
//   index_dims    - dims of the index tensors.
//   index_strides - strides of the index tensors.
//   slice_offset  - byte offset applied to input's data pointer (it is added
//                   to a char*), selecting a slice of the allocation.
//   output        - gathered result; allocated here via Alloc<T>.
template <typename T, typename Context, typename IndexT = int>
void XPUIndexElementwiseGetKernel(const Context& dev_ctx,
                                  const DenseTensor& input,
                                  const std::vector<const DenseTensor*>& index,
                                  const std::vector<int64_t>& input_dims,
                                  const std::vector<int64_t>& input_strides,
                                  const std::vector<int64_t>& index_dims,
                                  const std::vector<int64_t>& index_strides,
                                  const int64_t slice_offset,
                                  DenseTensor* output) {
  int64_t numel = 0;
  int64_t num_indices = 0;
  std::vector<int64_t> shape_tmp;
  std::vector<int64_t> stride_tmp;
  // Derive the number of index dimensions and a canonical shape/stride pair
  // from index_dims.
  funcs::cal_shape_stride(index_dims, &num_indices, &shape_tmp, &stride_tmp);

  // Copy the per-index-dim sizes/strides into fixed-capacity arrays
  // (DDim::kMaxRank bounds num_indices).
  auto sizes = std::array<int64_t, DDim::kMaxRank>{};
  auto strides = std::array<int64_t, DDim::kMaxRank>{};
  for (int64_t i = 0; i < num_indices; i++) {
    sizes[i] = index_dims[i];
    strides[i] = index_strides[i];
  }
  std::array<int64_t*, 3> strides_array;
  std::vector<int64_t> desired_shape;
  std::array<std::vector<int64_t>, 3> strides_vec;
  // Collapse output/input/index strides into a 3-operand iteration plan
  // (desired_shape + strides_array); same helper family the CPU/GPU
  // kernels use. The empty vectors stand in for the unused second operand.
  funcs::IndexGetStride<3>(input_dims,
                           input_strides,
                           phi::SizeOf(input.dtype()),
                           std::vector<int64_t>(),
                           std::vector<int64_t>(),
                           phi::SizeOf(input.dtype()),
                           shape_tmp,
                           stride_tmp,
                           phi::SizeOf(index[0]->dtype()),
                           &desired_shape,
                           &strides_array,
                           &numel,
                           strides_vec);
  // The element count handed to the XPU primitive must fit in 32 bits.
  const int64_t N = output->numel();
  PADDLE_ENFORCE_GE(
      N, 0, common::errors::InvalidArgument("Output numel must >= 0"));
  PADDLE_ENFORCE_LE(
      N,
      std::numeric_limits<int32_t>::max(),
      common::errors::InvalidArgument("Output numel must <= INT32_MAX"));

  dev_ctx.template Alloc<T>(output);
  using XPUType = typename XPUTypeTrait<T>::Type;
  using XPUTypeIndexT = typename XPUTypeTrait<IndexT>::Type;

  // passed vector params for XPU
  std::vector<const XPUTypeIndexT*> index_ptrs_vec;
  std::vector<int64_t> index_numel_vec;
  for (int i = 0; i < num_indices; i++) {
    // since XPU WRAPPER_CHECK_PTR only supports original GM ptrs, so we pass
    // the IndexT* type ptrs, which is different from the CPU/GPU's char* ptr.
    index_ptrs_vec.push_back(
        reinterpret_cast<const XPUTypeIndexT*>(index[i]->data<IndexT>()));
    // index_numel_vec is for the length of WRAPPER_CHECK_PTR
    index_numel_vec.push_back(index[i]->numel());
  }
  // Trim the fixed-capacity arrays down to the first num_indices entries.
  std::vector<int64_t> sizes_vec =
      std::vector<int64_t>(sizes.begin(), sizes.begin() + num_indices);
  std::vector<int64_t> orig_strides_vec =
      std::vector<int64_t>(strides.begin(), strides.begin() + num_indices);
  std::vector<std::vector<int64_t>> strides_vec_vec =
      std::vector<std::vector<int64_t>>(strides_vec.begin(), strides_vec.end());

  // slice_offset is a byte offset, hence the char* arithmetic.
  const char* in_ptr =
      reinterpret_cast<const char*>(input.data<T>()) + slice_offset;
  char* out_ptr = reinterpret_cast<char*>(output->data<T>());

  // for checkptr and checksum in XPU: usable bytes past each tensor's offset.
  int64_t data_size_in = input.Holder()->size() - input.meta().offset;
  int64_t data_size_out = output->Holder()->size() - output->meta().offset;

  bool is_get = true;
  int r = xpu::index_elementwise_tensor<XPUType, XPUTypeIndexT>(
      dev_ctx.x_context(),
      reinterpret_cast<const XPUType*>(in_ptr),  // XPU ptr
      reinterpret_cast<XPUType*>(out_ptr),       // XPU ptr
      index_ptrs_vec,                            // vec of XPU ptrs
      index_numel_vec,                           // CPU vec
      desired_shape,                             // CPU vec
      sizes_vec,                                 // CPU vec
      orig_strides_vec,                          // CPU vec
      strides_vec_vec,                           // CPU vec
      N,                                         // int64_t
      data_size_in,                              // int64_t
      data_size_out,                             // int64_t
      is_get);  // true for get, false for put
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_elementwise_tensor_get");
}

// Public entry point for the XPU index_elementwise_get kernel: validates the
// index dtype, shapes and allocates `out`, then hands off to the XPU-specific
// gather body above. `accumulate` / `is_combined` are accepted for signature
// compatibility with the other backends and are not used on this path.
template <typename T, typename Context>
void IndexElementwiseGetKernel(const Context& dev_ctx,
                               const DenseTensor& x,
                               const std::vector<const DenseTensor*>& index,
                               const std::vector<int64_t>& input_dims,
                               const std::vector<int64_t>& input_strides,
                               const std::vector<int64_t>& index_dims,
                               const std::vector<int64_t>& index_strides,
                               const int64_t slice_offset,
                               const bool accumulate,
                               const bool is_combined,
                               DenseTensor* out) {
  // Only int64 index tensors are supported by this path.
  const auto& idx_dtype = index[0]->dtype();
  PADDLE_ENFORCE_EQ(idx_dtype == phi::DataType::INT64,
                    true,
                    common::errors::InvalidArgument(
                        "Index holds the wrong type, it holds [%s], but "
                        "desires to be [%s].",
                        idx_dtype,
                        phi::DataType::INT64));

  // Non-scalar outputs are reshaped to the (possibly sliced) input dims.
  if (out->dims().size() > 0) {
    out->Resize(phi::make_ddim(std::vector<int64_t>(input_dims)));
  }
  dev_ctx.template Alloc<T>(out);
  // Nothing to gather for an empty output.
  if (out->numel() == 0) {
    return;
  }

  XPUIndexElementwiseGetKernel<T, Context, int64_t>(dev_ctx,
                                                    x,
                                                    index,
                                                    input_dims,
                                                    input_strides,
                                                    index_dims,
                                                    index_strides,
                                                    slice_offset,
                                                    out);
}

} // namespace phi

// Register the XPU index_elementwise_get kernel for all supported dtypes.
// NOTE(review): int16_t is registered here, but the xpu3_op_list.cc entry
// for "index_elementwise_get" in this PR lists phi::DataType::FLOAT64 twice
// and contains no INT16 — the duplicated FLOAT64 looks like it was meant to
// be INT16; confirm the op list matches this registration.
PD_REGISTER_KERNEL(index_elementwise_get,
                   XPU,
                   ALL_LAYOUT,
                   phi::IndexElementwiseGetKernel,
                   bool,
                   float,
                   double,
                   int,
                   int8_t,
                   int64_t,
                   int16_t,
                   uint8_t,
                   phi::float16,
                   phi::bfloat16) {}
Loading

AltStyle によって変換されたページ (->オリジナル) /