@@ -16,9 +16,13 @@ limitations under the License.
16
16
17
17
#include " stablehlo/dialect/ChloOps.h"
18
18
19
+ #include < algorithm>
19
20
#include < cassert>
20
21
#include < cstdint>
22
+ #include < iostream>
23
+ #include < iterator>
21
24
#include < optional>
25
+ #include < string>
22
26
23
27
#include " llvm/ADT/STLExtras.h"
24
28
#include " llvm/ADT/SmallVector.h"
@@ -426,12 +430,12 @@ namespace {
426
430
// Mode 1, where the ragged dimension is an lhs non-contracting dim (m).
427
431
// lhs : [b, m, k]
428
432
// rhs : [g, b, k, n]
429
- // group_sizes : [g]
433
+ // group_sizes : [b, g]
430
434
// result : [b, m, n]
431
435
// Mode 2, where the ragged dimension is an lhs/rhs contracting dim (k).
432
436
// lhs : [b, m, k]
433
437
// rhs : [b, k, n]
434
- // group_sizes : [g]
438
+ // group_sizes : [b, g]
435
439
// result : [g, b, m, n]
436
440
// Mode 3, where the ragged dimension is an lhs/rhs batch dim (b).
437
441
// lhs : [b, m, k]
@@ -440,9 +444,18 @@ namespace {
440
444
// result : [b, m, n]
441
445
// As with dot_general, the lhs and rhs can have arbitrary batching,
442
446
// contracting and non-contracting dimensions.
447
+ // The group_sizes arg has the shape [b...,x...,g], where:
448
+ // - b... are all the lhs batch dims before (outer-to) the lhs ragged dim,
449
+ // - x... are,
450
+ // - in mode 1, all the lhs non-contracting dims before the lhs ragged dim,
451
+ // - in mode 2, all the lhs contracting dims before the lhs ragged dim, and
452
+ // - in mode 3, empty;
453
+ // - g is the number of groups in the lhs ragged dim.
443
454
// Additionally:
444
455
// - In all modes, the lhs must have exactly one ragged dimension.
445
456
// - In mode 1, the rhs must have exactly one group dimension.
457
+ // - If a group_sizes of shape [g] is passed, it is broadcasted according to
458
+ // the rules above.
446
459
LogicalResult checkRaggedDotConstraints (
447
460
std::optional<Location> location, RankedTensorType rankedLhsType,
448
461
RankedTensorType rankedRhsType, RankedTensorType rankedGroupSizesType,
@@ -452,14 +465,6 @@ LogicalResult checkRaggedDotConstraints(
452
465
ArrayRef<int64_t > rhsContractingDimensions,
453
466
ArrayRef<int64_t > lhsRaggedDimensions,
454
467
ArrayRef<int64_t > rhsGroupDimensions) {
455
- // Check that the group sizes has rank=1.
456
- if (rankedGroupSizesType.getRank () != 1 ) {
457
- return emitOptionalError (
458
- location, " expected rank of group_sizes of ragged dot to be 1, got " ,
459
- rankedGroupSizesType.getRank ());
460
- }
461
- auto numGroups = rankedGroupSizesType.getDimSize (0 );
462
-
463
468
// Check that there is exactly one lhs ragged dimension.
464
469
if (lhsRaggedDimensions.size () != 1 ) {
465
470
return emitOptionalError (
@@ -474,6 +479,82 @@ LogicalResult checkRaggedDotConstraints(
474
479
return failure ();
475
480
}
476
481
482
+ enum Mode {
483
+ // Ragged non-contracting (m): [b,m,k], [g,b,k,n], [b,g] -> [b,m,n].
484
+ kNonContracting ,
485
+ // Ragged contracting (k): [b,m,k], [b,k,n], [b,g] -> [g,b,m,n].
486
+ kContracting ,
487
+ // Ragged batch (b): [b,m,k], [b,k,n], [g] -> [b,m,n].
488
+ kBatch
489
+ };
490
+ Mode mode;
491
+ if (llvm::is_contained (lhsBatchingDimensions, lhsRaggedDim)) {
492
+ mode = kBatch ;
493
+ } else if (llvm::is_contained (lhsContractingDimensions, lhsRaggedDim)) {
494
+ mode = kContracting ;
495
+ } else {
496
+ mode = kNonContracting ;
497
+ }
498
+
499
+ // Validate the shape of group_sizes.
500
+ {
501
+ // Construct the expected shape [b...,x...,g] of group_sizes.
502
+ SmallVector<int64_t > prefixDims;
503
+ prefixDims.reserve (rankedLhsType.getRank () - 1 );
504
+ prefixDims.insert (prefixDims.end (), lhsBatchingDimensions.begin (),
505
+ lhsBatchingDimensions.end ());
506
+ switch (mode) {
507
+ case kBatch :
508
+ prefixDims.resize (
509
+ std::distance (lhsBatchingDimensions.begin (),
510
+ llvm::find (lhsBatchingDimensions, lhsRaggedDim)));
511
+ break ;
512
+ case kContracting :
513
+ prefixDims.insert (prefixDims.end (), lhsContractingDimensions.begin (),
514
+ llvm::find (lhsContractingDimensions, lhsRaggedDim));
515
+ break ;
516
+ case kNonContracting :
517
+ for (int64_t i = 0 ; i < lhsRaggedDim; ++i) {
518
+ if (!llvm::is_contained (lhsBatchingDimensions, i) &&
519
+ !llvm::is_contained (lhsContractingDimensions, i)) {
520
+ prefixDims.push_back (i);
521
+ }
522
+ }
523
+ break ;
524
+ }
525
+ SmallVector<int64_t > expectedPrefix;
526
+ expectedPrefix.reserve (prefixDims.size ());
527
+ for (const int64_t dim : prefixDims) {
528
+ expectedPrefix.push_back (rankedLhsType.getDimSize (dim));
529
+ }
530
+
531
+ // Validate the actual shape, if it was passed as something other than [g].
532
+ if (rankedGroupSizesType.getRank () != 1 ) {
533
+ if (rankedGroupSizesType.getRank () !=
534
+ static_cast <int64_t >(expectedPrefix.size ()) + 1 ) {
535
+ return emitOptionalError (location, " expected group_sizes to have rank " ,
536
+ expectedPrefix.size () + 1 , " , got " ,
537
+ rankedGroupSizesType.getRank ());
538
+ }
539
+ auto groupSizesShape = rankedGroupSizesType.getShape ();
540
+ if (!std::equal (expectedPrefix.begin (), expectedPrefix.end (),
541
+ groupSizesShape.begin ())) {
542
+ auto nonEmptyShapeStr = [](ArrayRef<int64_t > shape) {
543
+ std::string s = " " ;
544
+ for (size_t i = 0 ; i < shape.size () - 1 ; ++i) {
545
+ s += std::to_string (shape[i]) + " , " ;
546
+ }
547
+ return s + std::to_string (shape.back ());
548
+ };
549
+ return emitOptionalError (
550
+ location, " group_sizes is expected to have shape [" ,
551
+ nonEmptyShapeStr (expectedPrefix), " , " , groupSizesShape.back (),
552
+ " ], got [" , nonEmptyShapeStr (groupSizesShape), " ]" );
553
+ }
554
+ }
555
+ }
556
+ const int64_t numGroups = rankedGroupSizesType.getShape ().back ();
557
+
477
558
// Validate basic properties of the rhs group dimension(s).
478
559
for (auto rhsGroupDim : rhsGroupDimensions) {
479
560
if (failed (hlo::checkDimInBounds (location, rhsGroupDim,
@@ -491,32 +572,34 @@ LogicalResult checkRaggedDotConstraints(
491
572
return failure ();
492
573
}
493
574
494
- if (llvm::is_contained (lhsBatchingDimensions, lhsRaggedDim) ||
495
- llvm::is_contained (lhsContractingDimensions, lhsRaggedDim)) {
496
- // Ragged batch (b): [b,m,k], [b,k,n], [g] -> [b,m,n].
497
- // Ragged contracting (k): [b,m,k], [b,k,n], [g] -> [g,b,m,n].
498
- if (!rhsGroupDimensions.empty ()) {
499
- return emitOptionalError (
500
- location,
501
- " There must be zero group dimensions in the rhs when the "
502
- " ragged dimension is batch or contracting." );
503
- }
504
- } else {
505
- // Ragged non-contracting (m): [b,m,k], [g,b,k,n], [g] -> [b,m,n].
506
- if (rhsGroupDimensions.size () != 1 ) {
507
- return emitOptionalError (
508
- location,
509
- " There must be exactly one group dimension in the rhs when the lhs "
510
- " ragged dimension is non-contracting." );
511
- }
512
- // Compare the group dimension size with the number of groups.
513
- const int64_t rhsGroupDim = rhsGroupDimensions[0 ];
514
- if (!hlo::verifyCompatibleDims (numGroups,
515
- rankedRhsType.getDimSize (rhsGroupDim))) {
516
- return emitOptionalError (
517
- location, " group_sizes is expected to have shape=[" ,
518
- rankedRhsType.getDimSize (rhsGroupDim), " ], got [" , numGroups, " ]" );
519
- }
575
+ switch (mode) {
576
+ case kBatch :
577
+ [[fallthrough]];
578
+ case kContracting :
579
+ if (!rhsGroupDimensions.empty ()) {
580
+ return emitOptionalError (
581
+ location,
582
+ " There must be zero group dimensions in the rhs when the "
583
+ " ragged dimension is batch or contracting." );
584
+ }
585
+ break ;
586
+ case kNonContracting :
587
+ if (rhsGroupDimensions.size () != 1 ) {
588
+ return emitOptionalError (
589
+ location,
590
+ " There must be exactly one group dimension in the rhs when the lhs "
591
+ " ragged dimension is non-contracting." );
592
+ }
593
+ // Compare the group dimension size with the number of groups.
594
+ const int64_t rhsGroupDim = rhsGroupDimensions[0 ];
595
+ if (!hlo::verifyCompatibleDims (numGroups,
596
+ rankedRhsType.getDimSize (rhsGroupDim))) {
597
+ return emitOptionalError (
598
+ location,
599
+ " rhs group dimension is expected to have size=" , numGroups,
600
+ " , got " , rankedRhsType.getDimSize (rhsGroupDim));
601
+ }
602
+ break ;
520
603
}
521
604
return success ();
522
605
}
@@ -530,10 +613,10 @@ SmallVector<int64_t> inferRaggedDotOutputDimensions(
530
613
ArrayRef<int64_t > rhsContractingDimensions,
531
614
ArrayRef<int64_t > lhsRaggedDimensions,
532
615
ArrayRef<int64_t > rhsGroupDimensions) {
533
- // Must have already checked that group_sizes is 1-D.
534
- const int64_t numGroups = rankedGroupSizesType.getDimSize (0 );
535
616
// Must have already checked that there is exactly one lhs ragged dim.
536
617
const int64_t lhsRaggedDim = lhsRaggedDimensions[0 ];
618
+ // Must have already checked the shape of group_sizes.
619
+ const int64_t numGroups = rankedGroupSizesType.getShape ().back ();
537
620
538
621
SmallVector<int64_t > dimensions;
539
622
// Add the group dimension to the result shape in case of ragged contracting.
0 commit comments