diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -604,7 +604,8 @@
     If `index` is negative, it is treated as indexing from the back of the
     shape. This negative-handling behavior is important when handling unranked
     shapes, where the positive index is not necessarily knowable due to a
-    dynamic number of leading dimensions.
+    dynamic number of leading dimensions. If the result is in extent tensor
+    form, out-of-bounds indices result in undefined behavior.
 
     Examples:
    - split_at([4,5,6], index=0) -> [], [4,5,6]
@@ -623,7 +624,8 @@
 
   let arguments = (ins Shape_ShapeOrExtentTensorType:$operand,
                        Shape_SizeOrIndexType:$index);
-  let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+  let results = (outs Shape_ShapeOrExtentTensorType:$head,
+                      Shape_ShapeOrExtentTensorType:$tail);
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -590,6 +590,47 @@
   return success();
 }
 
+namespace {
+class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
+public:
+  using OpConversionPattern<SplitAtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult SplitAtOpConversion::matchAndRewrite(
+    SplitAtOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // Error conditions are not implemented; only lower if all operands and
+  // results are extent tensors.
+  if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
+                   [](Value v) { return v.getType().isa<ShapeType>(); }))
+    return failure();
+
+  SplitAtOp::Adaptor transformed(op);
+  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+  Value zero = b.create<ConstantIndexOp>(0);
+  Value rank = b.create<DimOp>(transformed.operand(), zero);
+
+  // index < 0 ? index + rank : index
+  Value originalIndex = transformed.index();
+  Value add = b.create<AddIOp>(originalIndex, rank);
+  Value indexIsNegative =
+      b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
+  Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
+
+  Value one = b.create<ConstantIndexOp>(1);
+  Value head = b.create<SubTensorOp>(transformed.operand(), zero, index, one);
+  Value tailSize = b.create<SubIOp>(rank, index);
+  Value tail =
+      b.create<SubTensorOp>(transformed.operand(), index, tailSize, one);
+  rewriter.replaceOp(op, {head, tail});
+  return success();
+}
+
 namespace {
 class ToExtentTensorOpConversion
     : public OpConversionPattern<ToExtentTensorOp> {
@@ -660,6 +701,7 @@
       ReduceOpConverter,
       ShapeEqOpConverter,
       ShapeOfOpConversion,
+      SplitAtOpConversion,
       ToExtentTensorOpConversion>(ctx);
   // clang-format on
 }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -592,3 +592,23 @@
       : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
   return
 }
+
+// -----
+
+// Lower `split_at`
+// CHECK-LABEL: @split_at
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>, %[[INDEX:.*]]: index
+func @split_at(%shape: tensor<?xindex>, %index: index) -> (tensor<?xindex>, tensor<?xindex>) {
+  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+  // CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+  // CHECK-NEXT: %[[POSINDEX:.*]] = addi %[[INDEX]], %[[RANK]] : index
+  // CHECK-NEXT: %[[ISNEG:.*]] = cmpi slt, %[[INDEX]], %[[C0]] : index
+  // CHECK-NEXT: %[[SELECT:.*]] = select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index
+  // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+  // CHECK-NEXT: %[[HEAD:.*]] = subtensor %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK-NEXT: %[[TAIL_SIZE:.*]] = subi %[[RANK]], %[[SELECT]] : index
+  // CHECK-NEXT: %[[TAIL:.*]] = subtensor %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK-NEXT: return %[[HEAD]], %[[TAIL]] : tensor<?xindex>, tensor<?xindex>
+  %head, %tail = "shape.split_at"(%shape, %index) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+  return %head, %tail : tensor<?xindex>, tensor<?xindex>
+}
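
A quick way to exercise the new lowering end to end is to feed a small function through the existing ShapeToStandard conversion, as the new test does. A minimal sketch follows; it assumes the pass keeps its existing `-convert-shape-to-std` registration, and the function name and input below are hypothetical, not part of this patch:

  // RUN: mlir-opt --convert-shape-to-std --split-input-file %s | FileCheck %s

  // Split off the trailing extent. The constant -1 is normalized to
  // `rank - 1` by the addi/cmpi/select sequence introduced above, so the
  // head subtensor covers extents [0, rank - 1) and the tail has size 1.
  func @split_off_last(%shape : tensor<?xindex>)
      -> (tensor<?xindex>, tensor<?xindex>) {
    %c = constant -1 : index
    %head, %tail = "shape.split_at"(%shape, %c)
        : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
    return %head, %tail : tensor<?xindex>, tensor<?xindex>
  }

For `!shape.shape`-typed operands or results the pattern deliberately returns failure, so such ops pass through unchanged; and since out-of-bounds indices are undefined behavior per the updated op documentation, no bounds checks are emitted.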