diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -106,6 +106,30 @@
   let assemblyFormat = "attr-dict $value";
 }
 
+def Shape_FromIndexOp : Shape_Op<"from_index", [NoSideEffect]> {
+  let summary = "Creates a shape size from a standard index";
+  let description = [{
+    Converts a standard index to a shape size.
+    This is to convert between types of the standard dialect and types of the
+    shape dialect.
+  }];
+
+  let arguments = (ins Index:$input);
+  let results = (outs Shape_SizeType:$result);
+}
+
+def Shape_ToIndexOp : Shape_Op<"to_index", [NoSideEffect]> {
+  let summary = "Creates a standard index from a shape size";
+  let description = [{
+    Converts a shape size to a standard index.
+    This is to convert between types of the standard dialect and types of the
+    shape dialect.
+  }];
+
+  let arguments = (ins Shape_SizeType:$input);
+  let results = (outs Index:$result);
+}
+
 def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
   let summary = "Creates a shape from a tensor of extents";
   let description = [{
@@ -114,7 +138,7 @@
     extents match the values of the elements.
   }];
 
-  let arguments = (ins I32Tensor:$input);
+  let arguments = (ins IndexTensor:$input);
   let results = (outs Shape_ShapeType:$result);
 }
 
@@ -204,14 +228,14 @@
     number of elements
 
     ```mlir
-    func @shape_num_elements(%shape : !shape.type) -> !shape.size {
+    func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
       %0 = "shape.constant_dim"() {value = 1 : i32} : () -> !shape.size
       %1 = "shape.reduce"(%shape, %0) ( {
         ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
           %acc = "shape.mul"(%lci, %dim) :
             (!shape.size, !shape.size) -> !shape.size
           shape.yield %acc : !shape.size
-        }) : (!shape.type, !shape.size) -> (!shape.size)
+        }) : (!shape.shape, !shape.size) -> (!shape.size)
       return %1 : !shape.size
     }
     ```
@@ -225,6 +249,18 @@
   let regions = (region SizedRegion<1>:$body);
 }
 
+def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
+  let summary = "Returns the number of elements for a given shape";
+  let description = [{
+    Returns the number of elements for a given shape which is the product of its
+    dimensions.
+    A tensor of the given shape will hold this many elements.
+  }];
+
+  let arguments = (ins Shape_ShapeType:$shape);
+  let results = (outs Shape_SizeType:$result);
+}
+
 def Shape_ShapeOfOp : Shape_Op<"shape_of",
     [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns shape of a value or shaped type operand";
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -69,6 +69,8 @@
   IndexType getIndexType();
 
   IntegerType getI1Type();
+  IntegerType getI32Type();
+  IntegerType getI64Type();
   IntegerType getIntegerType(unsigned width);
   IntegerType getIntegerType(unsigned width, bool isSigned);
   FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -396,6 +396,7 @@
 
 //===----------------------------------------------------------------------===//
 // RankedTensorType
+//===----------------------------------------------------------------------===//
 
 /// Ranked tensor types represent multi-dimensional arrays that have a shape
 /// with a fixed number of dimensions. Each shape element can be a non-negative
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -54,6 +54,10 @@
 
 IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
 
+IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
+
+IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
+
 IntegerType Builder::getIntegerType(unsigned width) {
   return IntegerType::get(width, context);
 }