Skip to content

Commit 80af9c3

Browse files
Implement dpnp.linalg.lu_solve() batch inputs (#2619)
This PR extends `dpnp.linalg.lu_solve()` (#2575) to support batched input arrays and adds a new LAPACK `getrs_batch` extension.
1 parent 800f642 commit 80af9c3

File tree

12 files changed

+840
-56
lines changed

12 files changed

+840
-56
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ This release changes the license from `BSD-2-Clause` to `BSD-3-Clause`.
1111
### Added
1212

1313
* Added the docstrings to `dpnp.linalg.LinAlgError` exception [#2613](https://github.com/IntelPython/dpnp/pull/2613)
14+
* Added implementation of `dpnp.linalg.lu_solve` for batch inputs (SciPy-compatible) [#2619](https://github.com/IntelPython/dpnp/pull/2619)
1415

1516
### Changed
1617

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ set(_module_src
4141
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
4242
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
4343
${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp
44+
${CMAKE_CURRENT_SOURCE_DIR}/getrs_batch.cpp
4445
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
4546
${CMAKE_CURRENT_SOURCE_DIR}/heevd_batch.cpp
4647
${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp

dpnp/backend/extensions/lapack/getrs.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ namespace type_utils = dpctl::tensor::type_utils;
5151
using ext::common::init_dispatch_vector;
5252

5353
typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue &,
54-
oneapi::mkl::transpose,
54+
const oneapi::mkl::transpose,
5555
const std::int64_t,
5656
const std::int64_t,
5757
char *,
58-
std::int64_t,
58+
const std::int64_t,
5959
std::int64_t *,
6060
char *,
61-
std::int64_t,
61+
const std::int64_t,
6262
std::vector<sycl::event> &,
6363
const std::vector<sycl::event> &);
6464

@@ -70,10 +70,10 @@ static sycl::event getrs_impl(sycl::queue &exec_q,
7070
const std::int64_t n,
7171
const std::int64_t nrhs,
7272
char *in_a,
73-
std::int64_t lda,
73+
const std::int64_t lda,
7474
std::int64_t *ipiv,
7575
char *in_b,
76-
std::int64_t ldb,
76+
const std::int64_t ldb,
7777
std::vector<sycl::event> &host_task_events,
7878
const std::vector<sycl::event> &depends)
7979
{
@@ -234,7 +234,7 @@ std::pair<sycl::event, sycl::event>
234234
throw py::value_error("The right-hand sides array "
235235
"must be F-contiguous");
236236
}
237-
if (!is_ipiv_array_c_contig || !is_ipiv_array_f_contig) {
237+
if (!is_ipiv_array_c_contig && !is_ipiv_array_f_contig) {
238238
throw py::value_error("The array of pivot indices "
239239
"must be contiguous");
240240
}

dpnp/backend/extensions/lapack/getrs.hpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,23 @@ extern std::pair<sycl::event, sycl::event>
4040
const dpctl::tensor::usm_ndarray &a_array,
4141
const dpctl::tensor::usm_ndarray &ipiv_array,
4242
const dpctl::tensor::usm_ndarray &b_array,
43-
oneapi::mkl::transpose trans,
43+
const oneapi::mkl::transpose trans,
4444
const std::vector<sycl::event> &depends = {});
4545

46+
extern std::pair<sycl::event, sycl::event>
47+
getrs_batch(sycl::queue &exec_q,
48+
const dpctl::tensor::usm_ndarray &a_array,
49+
const dpctl::tensor::usm_ndarray &ipiv_array,
50+
const dpctl::tensor::usm_ndarray &b_array,
51+
const oneapi::mkl::transpose trans,
52+
const std::int64_t n,
53+
const std::int64_t nrhs,
54+
const std::int64_t stride_a,
55+
const std::int64_t stride_ipiv,
56+
const std::int64_t stride_b,
57+
const std::int64_t batch_size,
58+
const std::vector<sycl::event> &depends = {});
59+
4660
extern void init_getrs_dispatch_vector(void);
61+
extern void init_getrs_batch_dispatch_vector(void);
4762
} // namespace dpnp::extensions::lapack

0 commit comments

Comments
 (0)