Skip to content

Commit 5e9b356

Browse files
authored
1 parent c95da49 commit 5e9b356

10 files changed

+386
-60
lines changed

BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ cc_library(
994994
":linalg_passes",
995995
":reference_api",
996996
":reference_configuration",
997-
":stablehlo_dialect_capi_objects",
997+
":stablehlo_dialect_capi",
998998
":stablehlo_ops",
999999
":stablehlo_passes",
10001000
":stablehlo_portable_api",

WORKSPACE.bazel

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")
1717

1818
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
1919

20-
LLVM_COMMIT = "0e779ad4998ef65907502101c5b82ede05ddfa4e"
20+
LLVM_COMMIT = "43d71baae36c8d8b5a9995aa35efebe09cc9c2d6"
2121

22-
LLVM_SHA256 = "d5c2560b2d9ce3ced7951113f2b5d1ea428665678f4dcb1fb8780eb1219ca615"
22+
LLVM_SHA256 = "436af8b4c3403e251ab0b7a471eda7df6063f9da9d22ccbe498f3115cd35225a"
2323

2424
http_archive(
2525
name = "llvm-raw",

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0e779ad4998ef65907502101c5b82ede05ddfa4e
1+
43d71baae36c8d8b5a9995aa35efebe09cc9c2d6

stablehlo/dialect/ChloOps.cpp

Lines changed: 121 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@ limitations under the License.
1616

1717
#include "stablehlo/dialect/ChloOps.h"
1818

19+
#include <algorithm>
1920
#include <cassert>
2021
#include <cstdint>
22+
#include <iostream>
23+
#include <iterator>
2124
#include <optional>
25+
#include <string>
2226

2327
#include "llvm/ADT/STLExtras.h"
2428
#include "llvm/ADT/SmallVector.h"
@@ -426,12 +430,12 @@ namespace {
426430
// Mode 1, where the ragged dimension is an lhs non-contracting dim (m).
427431
// lhs : [b, m, k]
428432
// rhs : [g, b, k, n]
429-
// group_sizes : [g]
433+
// group_sizes : [b, g]
430434
// result : [b, m, n]
431435
// Mode 2, where the ragged dimension is an lhs/rhs contracting dim (k).
432436
// lhs : [b, m, k]
433437
// rhs : [b, k, n]
434-
// group_sizes : [g]
438+
// group_sizes : [b, g]
435439
// result : [g, b, m, n]
436440
// Mode 3, where the ragged dimension is an lhs/rhs batch dim (b).
437441
// lhs : [b, m, k]
@@ -440,9 +444,18 @@ namespace {
440444
// result : [b, m, n]
441445
// As with dot_general, the lhs and rhs can have arbitrary batching,
442446
// contracting and non-contracting dimensions.
447+
// The group_sizes arg has the shape [b...,x...,g], where:
448+
// - b... are all the lhs batch dims before (outer-to) the lhs ragged dim,
449+
// - x... are,
450+
// - in mode 1, all the lhs non-contracting dims before the lhs ragged dim,
451+
// - in mode 2, all the lhs contracting dims before the lhs ragged dim, and
452+
// - in mode 3, empty;
453+
// - g is the number of groups in the lhs ragged dim.
443454
// Additionally:
444455
// - In all modes, the lhs must have exactly one ragged dimension.
445456
// - In mode 1, the rhs must have exactly one group dimension.
457+
// - If a group_sizes of shape [g] is passed, it is broadcast according to
458+
// the rules above.
446459
LogicalResult checkRaggedDotConstraints(
447460
std::optional<Location> location, RankedTensorType rankedLhsType,
448461
RankedTensorType rankedRhsType, RankedTensorType rankedGroupSizesType,
@@ -452,14 +465,6 @@ LogicalResult checkRaggedDotConstraints(
452465
ArrayRef<int64_t> rhsContractingDimensions,
453466
ArrayRef<int64_t> lhsRaggedDimensions,
454467
ArrayRef<int64_t> rhsGroupDimensions) {
455-
// Check that the group sizes has rank=1.
456-
if (rankedGroupSizesType.getRank() != 1) {
457-
return emitOptionalError(
458-
location, "expected rank of group_sizes of ragged dot to be 1, got ",
459-
rankedGroupSizesType.getRank());
460-
}
461-
auto numGroups = rankedGroupSizesType.getDimSize(0);
462-
463468
// Check that there is exactly one lhs ragged dimension.
464469
if (lhsRaggedDimensions.size() != 1) {
465470
return emitOptionalError(
@@ -474,6 +479,82 @@ LogicalResult checkRaggedDotConstraints(
474479
return failure();
475480
}
476481

482+
enum Mode {
483+
// Ragged non-contracting (m): [b,m,k], [g,b,k,n], [b,g] -> [b,m,n].
484+
kNonContracting,
485+
// Ragged contracting (k): [b,m,k], [b,k,n], [b,g] -> [g,b,m,n].
486+
kContracting,
487+
// Ragged batch (b): [b,m,k], [b,k,n], [g] -> [b,m,n].
488+
kBatch
489+
};
490+
Mode mode;
491+
if (llvm::is_contained(lhsBatchingDimensions, lhsRaggedDim)) {
492+
mode = kBatch;
493+
} else if (llvm::is_contained(lhsContractingDimensions, lhsRaggedDim)) {
494+
mode = kContracting;
495+
} else {
496+
mode = kNonContracting;
497+
}
498+
499+
// Validate the shape of group_sizes.
500+
{
501+
// Construct the expected shape [b...,x...,g] of group_sizes.
502+
SmallVector<int64_t> prefixDims;
503+
prefixDims.reserve(rankedLhsType.getRank() - 1);
504+
prefixDims.insert(prefixDims.end(), lhsBatchingDimensions.begin(),
505+
lhsBatchingDimensions.end());
506+
switch (mode) {
507+
case kBatch:
508+
prefixDims.resize(
509+
std::distance(lhsBatchingDimensions.begin(),
510+
llvm::find(lhsBatchingDimensions, lhsRaggedDim)));
511+
break;
512+
case kContracting:
513+
prefixDims.insert(prefixDims.end(), lhsContractingDimensions.begin(),
514+
llvm::find(lhsContractingDimensions, lhsRaggedDim));
515+
break;
516+
case kNonContracting:
517+
for (int64_t i = 0; i < lhsRaggedDim; ++i) {
518+
if (!llvm::is_contained(lhsBatchingDimensions, i) &&
519+
!llvm::is_contained(lhsContractingDimensions, i)) {
520+
prefixDims.push_back(i);
521+
}
522+
}
523+
break;
524+
}
525+
SmallVector<int64_t> expectedPrefix;
526+
expectedPrefix.reserve(prefixDims.size());
527+
for (const int64_t dim : prefixDims) {
528+
expectedPrefix.push_back(rankedLhsType.getDimSize(dim));
529+
}
530+
531+
// Validate the actual shape, if it was passed as something other than [g].
532+
if (rankedGroupSizesType.getRank() != 1) {
533+
if (rankedGroupSizesType.getRank() !=
534+
static_cast<int64_t>(expectedPrefix.size()) + 1) {
535+
return emitOptionalError(location, "expected group_sizes to have rank ",
536+
expectedPrefix.size() + 1, ", got ",
537+
rankedGroupSizesType.getRank());
538+
}
539+
auto groupSizesShape = rankedGroupSizesType.getShape();
540+
if (!std::equal(expectedPrefix.begin(), expectedPrefix.end(),
541+
groupSizesShape.begin())) {
542+
auto nonEmptyShapeStr = [](ArrayRef<int64_t> shape) {
543+
std::string s = "";
544+
for (size_t i = 0; i < shape.size() - 1; ++i) {
545+
s += std::to_string(shape[i]) + ", ";
546+
}
547+
return s + std::to_string(shape.back());
548+
};
549+
return emitOptionalError(
550+
location, "group_sizes is expected to have shape [",
551+
nonEmptyShapeStr(expectedPrefix), ", ", groupSizesShape.back(),
552+
"], got [", nonEmptyShapeStr(groupSizesShape), "]");
553+
}
554+
}
555+
}
556+
const int64_t numGroups = rankedGroupSizesType.getShape().back();
557+
477558
// Validate basic properties of the rhs group dimension(s).
478559
for (auto rhsGroupDim : rhsGroupDimensions) {
479560
if (failed(hlo::checkDimInBounds(location, rhsGroupDim,
@@ -491,32 +572,34 @@ LogicalResult checkRaggedDotConstraints(
491572
return failure();
492573
}
493574

494-
if (llvm::is_contained(lhsBatchingDimensions, lhsRaggedDim) ||
495-
llvm::is_contained(lhsContractingDimensions, lhsRaggedDim)) {
496-
// Ragged batch (b): [b,m,k], [b,k,n], [g] -> [b,m,n].
497-
// Ragged contracting (k): [b,m,k], [b,k,n], [g] -> [g,b,m,n].
498-
if (!rhsGroupDimensions.empty()) {
499-
return emitOptionalError(
500-
location,
501-
"There must be zero group dimensions in the rhs when the "
502-
"ragged dimension is batch or contracting.");
503-
}
504-
} else {
505-
// Ragged non-contracting (m): [b,m,k], [g,b,k,n], [g] -> [b,m,n].
506-
if (rhsGroupDimensions.size() != 1) {
507-
return emitOptionalError(
508-
location,
509-
"There must be exactly one group dimension in the rhs when the lhs "
510-
"ragged dimension is non-contracting.");
511-
}
512-
// Compare the group dimension size with the number of groups.
513-
const int64_t rhsGroupDim = rhsGroupDimensions[0];
514-
if (!hlo::verifyCompatibleDims(numGroups,
515-
rankedRhsType.getDimSize(rhsGroupDim))) {
516-
return emitOptionalError(
517-
location, "group_sizes is expected to have shape=[",
518-
rankedRhsType.getDimSize(rhsGroupDim), "], got [", numGroups, "]");
519-
}
575+
switch (mode) {
576+
case kBatch:
577+
[[fallthrough]];
578+
case kContracting:
579+
if (!rhsGroupDimensions.empty()) {
580+
return emitOptionalError(
581+
location,
582+
"There must be zero group dimensions in the rhs when the "
583+
"ragged dimension is batch or contracting.");
584+
}
585+
break;
586+
case kNonContracting:
587+
if (rhsGroupDimensions.size() != 1) {
588+
return emitOptionalError(
589+
location,
590+
"There must be exactly one group dimension in the rhs when the lhs "
591+
"ragged dimension is non-contracting.");
592+
}
593+
// Compare the group dimension size with the number of groups.
594+
const int64_t rhsGroupDim = rhsGroupDimensions[0];
595+
if (!hlo::verifyCompatibleDims(numGroups,
596+
rankedRhsType.getDimSize(rhsGroupDim))) {
597+
return emitOptionalError(
598+
location,
599+
"rhs group dimension is expected to have size=", numGroups,
600+
", got ", rankedRhsType.getDimSize(rhsGroupDim));
601+
}
602+
break;
520603
}
521604
return success();
522605
}
@@ -530,10 +613,10 @@ SmallVector<int64_t> inferRaggedDotOutputDimensions(
530613
ArrayRef<int64_t> rhsContractingDimensions,
531614
ArrayRef<int64_t> lhsRaggedDimensions,
532615
ArrayRef<int64_t> rhsGroupDimensions) {
533-
// Must have already checked that group_sizes is 1-D.
534-
const int64_t numGroups = rankedGroupSizesType.getDimSize(0);
535616
// Must have already checked that there is exactly one lhs ragged dim.
536617
const int64_t lhsRaggedDim = lhsRaggedDimensions[0];
618+
// Must have already checked the shape of group_sizes.
619+
const int64_t numGroups = rankedGroupSizesType.getShape().back();
537620

538621
SmallVector<int64_t> dimensions;
539622
// Add the group dimension to the result shape in case of ragged contracting.

stablehlo/dialect/ChloOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -869,12 +869,12 @@ def CHLO_RaggedDotOp : CHLO_Op<"ragged_dot",
869869
most one group dimension. The op has three modes, depending on the kind of
870870
the lhs ragged dimension.
871871

872-
In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [g] -> [b,m,n]`.
872+
In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [b,g] -> [b,m,n]`.
873873
Here the ragged dimension is an lhs non-contracting dimension (`m`). The
874874
dimensions `b` and `k` represent batch and contracting dimensions
875875
respectively. The rhs is required to have a group dimension (`g`).
876876

877-
In mode 2, the shape-signature is `[b,m,k], [b,k,n], [g] -> [g,b,m,n]`.
877+
In mode 2, the shape-signature is `[b,m,k], [b,k,n], [b,g] -> [g,b,m,n]`.
878878
Here the ragged dimension is an lhs/rhs contracting dimension (`k`).
879879

880880
In mode 3, the shape-signature is `[b,m,k], [b,k,n], [g] -> [b,m,n]`. Here

stablehlo/tests/ops_chlo.mlir

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ func.func @ragged_dot_incompatible_contracting_dims(%lhs : tensor<11x5xf32>, %rh
146146
// -----
147147

148148
func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3x2xi64>) -> tensor<11x7xf32> {
149-
// @expected-error@+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}}
149+
// @expected-error@+1 {{expected group_sizes to have rank 1, got 2}}
150150
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
151151
ragged_dot_dimension_numbers = #chlo.ragged_dot<
152152
lhs_batching_dimensions = [],
@@ -163,8 +163,79 @@ func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs :
163163

164164
// -----
165165

166-
func.func @ragged_dot_group_sizes_incorrect_shape(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> {
167-
// @expected-error@+1 {{group_sizes is expected to have shape=[3], got [2]}}
166+
func.func @ragged_dot_mode1_group_sizes_broadcasted(%lhs : tensor<19x17x11x5xf32>, %rhs : tensor<3x19x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<19x17x11x7xf32> {
167+
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
168+
ragged_dot_dimension_numbers = #chlo.ragged_dot<
169+
lhs_batching_dimensions = [0],
170+
rhs_batching_dimensions = [1],
171+
lhs_contracting_dimensions = [3],
172+
rhs_contracting_dimensions = [2],
173+
lhs_ragged_dimensions = [2],
174+
rhs_group_dimensions = [0]
175+
>,
176+
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
177+
} : (tensor<19x17x11x5xf32>, tensor<3x19x5x7xf32>, tensor<3xi64>) -> tensor<19x17x11x7xf32>
178+
func.return %0 : tensor<19x17x11x7xf32>
179+
}
180+
181+
// -----
182+
183+
func.func @ragged_dot_mode1_group_sizes_incorrect_shape(%lhs : tensor<19x17x11x5xf32>, %rhs : tensor<3x19x5x7xf32>, %group_sizes : tensor<19x11x3xi64>) -> tensor<19x17x11x7xf32> {
184+
// @expected-error@+1 {{group_sizes is expected to have shape [19, 17, 3], got [19, 11, 3]}}
185+
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
186+
ragged_dot_dimension_numbers = #chlo.ragged_dot<
187+
lhs_batching_dimensions = [0],
188+
rhs_batching_dimensions = [1],
189+
lhs_contracting_dimensions = [3],
190+
rhs_contracting_dimensions = [2],
191+
lhs_ragged_dimensions = [2],
192+
rhs_group_dimensions = [0]
193+
>,
194+
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
195+
} : (tensor<19x17x11x5xf32>, tensor<3x19x5x7xf32>, tensor<19x11x3xi64>) -> tensor<19x17x11x7xf32>
196+
func.return %0 : tensor<19x17x11x7xf32>
197+
}
198+
199+
// -----
200+
201+
func.func @ragged_dot_mode2_group_sizes_incorrect_shape(%lhs : tensor<19x11x17x5xf32>, %rhs : tensor<19x17x5x7xf32>, %group_sizes : tensor<19x11x3xi64>) -> tensor<3x19x11x7xf32> {
202+
// @expected-error@+1 {{group_sizes is expected to have shape [19, 17, 3], got [19, 11, 3]}}
203+
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
204+
ragged_dot_dimension_numbers = #chlo.ragged_dot<
205+
lhs_batching_dimensions = [0],
206+
rhs_batching_dimensions = [0],
207+
lhs_contracting_dimensions = [2,3],
208+
rhs_contracting_dimensions = [1,2],
209+
lhs_ragged_dimensions = [3],
210+
rhs_group_dimensions = []
211+
>,
212+
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
213+
} : (tensor<19x11x17x5xf32>, tensor<19x17x5x7xf32>, tensor<19x11x3xi64>) -> tensor<3x19x11x7xf32>
214+
func.return %0 : tensor<3x19x11x7xf32>
215+
}
216+
217+
// -----
218+
219+
func.func @ragged_dot_mode3_group_sizes_incorrect_shape(%lhs : tensor<17x19x11x5xf32>, %rhs : tensor<17x19x5x7xf32>, %group_sizes : tensor<19x3xi64>) -> tensor<17x19x11x7xf32> {
220+
// @expected-error@+1 {{group_sizes is expected to have shape [17, 3], got [19, 3]}}
221+
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
222+
ragged_dot_dimension_numbers = #chlo.ragged_dot<
223+
lhs_batching_dimensions = [0,1],
224+
rhs_batching_dimensions = [0,1],
225+
lhs_contracting_dimensions = [3],
226+
rhs_contracting_dimensions = [2],
227+
lhs_ragged_dimensions = [1],
228+
rhs_group_dimensions = []
229+
>,
230+
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
231+
} : (tensor<17x19x11x5xf32>, tensor<17x19x5x7xf32>, tensor<19x3xi64>) -> tensor<17x19x11x7xf32>
232+
func.return %0 : tensor<17x19x11x7xf32>
233+
}
234+
235+
// -----
236+
237+
func.func @ragged_dot_incorrect_group_dim_size(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> {
238+
// @expected-error@+1 {{rhs group dimension is expected to have size=2, got 3}}
168239
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
169240
ragged_dot_dimension_numbers = #chlo.ragged_dot<
170241
lhs_batching_dimensions = [],

0 commit comments

Comments
 (0)