diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -16,6 +16,8 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffects.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,6 +14,7 @@
 #define SHAPE_OPS
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffects.td"
 
 // TODO(jpienaar): Move to base.
@@ -168,17 +169,37 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def Shape_CreateShapeOp : Shape_Op<"create_shape", []> {
-  let summary = "Creates a shape descriptor from a tensor";
+def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
+  let summary = "Creates a shape from a tensor of extents";
   let description = [{
-    Creates a shape from a 1D integral tensor. The rank equals the number of
-    elements in the tensor, and extent matches the values of the elements.
+    Creates a shape from a 1D integral tensor of extents. The rank of the
+    resulting shape equals the number of elements in the tensor, and the
+    extents match the values of the elements.
   }];
 
   let arguments = (ins I32Tensor:$input);
   let results = (outs Shape_ShapeType:$result);
 }
 
+def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", []> {
+  let summary = "Creates an extent tensor from a shape";
+  // TODO: Think more about the error situation. Perhaps factor out the
+  // error detection into a separate op so downstream consumers can control
+  // their error behavior. Then this op would assume that the input has
+  // been properly checked not to be an error (and could thus be a
+  // NoSideEffect op).
+  let description = [{
+    Converts a shape to a 1D integral tensor of extents. The number of elements
+    in the tensor equals the rank of the shape, and the elements equal the
+    extents of the shape.
+
+    If the shape represents an error, this op currently aborts the program.
+  }];
+
+  let arguments = (ins Shape_ShapeType:$input);
+  let results = (outs I32Tensor:$result);
+}
+
 def Shape_JoinOp : Shape_Op<"join", []> {
   let summary = "Returns the least general shape.size of its operands";
   let description = [{
@@ -299,4 +320,50 @@
   let results = (outs Shape_ShapeOrSizeType:$output);
 }
 
+def Shape_SplitAtOp : Shape_Op<"split_at",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Splits a shape at a given index.";
+  let description = [{
+    Splits a shape at a given dimension `index`, returning two shapes.
+    If `index` is negative, it is treated as indexing from the back of the
+    shape. This negative-index behavior is important when handling unranked
+    shapes, where the positive index is not necessarily knowable due to a
+    dynamic number of leading dimensions.
+
+    Examples:
+    - split_at([4,5,6], index=0) -> [], [4,5,6]
+    - split_at([4,5,6], index=1) -> [4], [5,6]
+    - split_at([4,5,6], index=2) -> [4,5], [6]
+    - split_at([4,5,6], index=3) -> [4,5,6], []
+    - split_at([4,5,6], index=4) -> error
+    - split_at([4,5,6], index=-1) -> [4,5], [6]
+    - split_at([4,5,6], index=-2) -> [4], [5,6]
+    - split_at([4,5,6], index=-3) -> [], [4,5,6]
+    - split_at([4,5,6], index=-4) -> error
+
+    Requires:
+    - `index` is in the range [-rank(operand), rank(operand)]
+  }];
+
+  let arguments = (ins Shape_ShapeType:$operand, I32:$index);
+  let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+}
+
+def Shape_ConcatOp : Shape_Op<"concat",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Concatenates two shapes.";
+  let description = [{
+    Creates a shape whose dimensions are the dimensions of `lhs` followed by
+    the dimensions of `rhs`.
+
+    Example:
+    concat([2,3], [4,5]) -> [2,3,4,5]
+    concat([], []) -> []
+    concat([], [4,5,6]) -> [4,5,6]
+  }];
+
+  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
+  let results = (outs Shape_ShapeType:$result);
+}
+
 #endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -9,6 +9,7 @@
   )
 target_link_libraries(MLIRShape
   PUBLIC
+  MLIRInferTypeOpInterface
   MLIRIR
   MLIRSideEffects
   LLVMSupport
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -106,6 +106,33 @@ static LogicalResult verify(ConstantOp &op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SplitAtOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SplitAtOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto shapeType = ShapeType::get(context);
+  inferredReturnTypes.push_back(shapeType);
+  inferredReturnTypes.push_back(shapeType);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConcatOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto shapeType = ShapeType::get(context);
+  inferredReturnTypes.push_back(shapeType);
+  return success();
+}
+
 namespace mlir {
 namespace shape {
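
For illustration only (not part of the patch): a minimal sketch of how the two new ops compose, written in MLIR's generic op form since the patch defines no custom assembly format for them. The function name and SSA value names are made up; the result types are spelled explicitly here even though InferTypeOpInterface lets builders infer them.

  func @split_then_concat(%shape : !shape.shape, %index : i32) -> !shape.shape {
    // Split %shape into the dimensions before %index (head) and the rest (tail).
    %head, %tail = "shape.split_at"(%shape, %index)
        : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
    // Concatenating head and tail in order yields the original shape again
    // (assuming %index was in the valid range for split_at).
    %result = "shape.concat"(%head, %tail)
        : (!shape.shape, !shape.shape) -> !shape.shape
    return %result : !shape.shape
  }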