diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -117,9 +117,9 @@
     [ConstantLike, NoSideEffect,
      DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
-  let summary = "Creates a constant of !shape.size type.";
+  let summary = "Creates a constant of type `!shape.size`";
   let description = [{
-    Creates a !shape.size type representing the constant size given by `value`.
+    Creates a `!shape.size` type representing the constant size given by `value`.
 
     ```mlir
     %x = shape.const_size 10
@@ -130,8 +130,49 @@
   let results = (outs Shape_SizeType:$result);
 
   let assemblyFormat = "attr-dict $value";
+
+  let hasFolder = 1;
+}
+
+def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Converts a standard index to a shape size";
+  let description = [{
+    Converts a standard `index` to a `shape.size`. This operation and its
+    inverse, `size_to_index`, facilitate index conversion between the standard
+    and the shape dialect.
+    Expects an `index` and returns a `shape.size` value.
+  }];
+
+  let arguments = (ins Index:$arg);
+  let results = (outs Shape_SizeType:$result);
+
+  let assemblyFormat = "attr-dict $arg";
+
+  let hasFolder = 1;
+}
+
+def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Converts a shape size to a standard index";
+  let description = [{
+    Converts a `shape.size` to a standard `index`. This operation and its
+    inverse, `index_to_size`, facilitate index conversion between the standard
+    and the shape dialect.
+    Expects a `shape.size` and returns an `index` value.
+  }];
+
+  let arguments = (ins Shape_SizeType:$arg);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = "attr-dict $arg";
+
+  let hasFolder = 1;
 }
 
 def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
   let summary = "Creates a shape from a tensor of extents";
   let description = [{
@@ -140,7 +181,7 @@
     extents match the values of the elements.
   }];
 
-  let arguments = (ins I32Tensor:$input);
+  let arguments = (ins IndexTensor:$input);
   let results = (outs Shape_ShapeType:$result);
 }
@@ -168,10 +209,10 @@
 def Shape_JoinOp : Shape_Op<"join", []> {
   let summary = "Returns the least general shape.size of its operands";
   let description = [{
-    An operation that computes the least general shape of input operands. This
-    effectively asserts that corresponding static dimensions are equal. The
-    behavior is to match each element of the `shape.type` and propagate the most
-    restrictive information, returning an invalid shape if there are
+    An operation that computes the least general shape of input operands.
+    This effectively asserts that corresponding static dimensions are equal.
+    The behavior is to match each element of the `shape.shape` and propagate the
+    most restrictive information, returning an invalid shape if there are
     contradictory requirements. E.g., using pseudo code
 
     ```
@@ -189,7 +230,7 @@
     used to return an error to the user upon mismatch of dimensions.
 
     ```mlir
-    %c = shape.join %a, %b, error="" : !shape.type
+    %c = shape.join %a, %b, error="" : !shape.shape
     ```
   }];
@@ -210,6 +251,26 @@
   let results = (outs Shape_SizeType:$result);
 }
 
+def Shape_NumElementsOp : Shape_Op<"num_elements", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Returns the number of elements for a given shape";
+  let description = [{
+    Returns the number of elements for a given shape, which is the product of
+    its dimensions. A tensor of the given shape will hold this many elements.
+    Expects a `shape.shape` and returns a `shape.size` value.
+  }];
+
+  let arguments = (ins Shape_ShapeType:$shape);
+  let results = (outs Shape_SizeType:$result);
+
+  let assemblyFormat = "attr-dict $shape";
+
+  let hasFolder = 1;
+}
+
 def Shape_ReduceOp : Shape_Op<"reduce", []> {
   let summary = "Returns an expression reduced over a shape";
   let description = [{
@@ -230,14 +291,14 @@
     number of elements
 
     ```mlir
-    func @shape_num_elements(%shape : !shape.type) -> !shape.size {
+    func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
       %0 = "shape.constant_dim"() {value = 1 : i32} : () -> !shape.size
       %1 = "shape.reduce"(%shape, %0) ( {
         ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
           %acc = "shape.mul"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
           shape.yield %acc : !shape.size
-      }) : (!shape.type, !shape.size) -> (!shape.size)
+      }) : (!shape.shape, !shape.size) -> (!shape.size)
       return %1 : !shape.size
     }
     ```
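Taken together, the new operations compose naturally with the existing ones. As an illustrative sketch (not part of the patch; the function name is invented), `num_elements` and `size_to_index` can be chained to obtain the element count of a shape as a standard `index`:

```mlir
func @element_count_as_index(%shape : !shape.shape) -> index {
  // Product of all extents of %shape, as a !shape.size.
  %num = shape.num_elements %shape
  // Leave the shape dialect to continue computing on standard index values.
  %idx = shape.size_to_index %num
  return %idx : index
}
```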
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -396,6 +396,7 @@
 
 //===----------------------------------------------------------------------===//
 // RankedTensorType
+//===----------------------------------------------------------------------===//
 
 /// Ranked tensor types represent multi-dimensional arrays that have a shape
 /// with a fixed number of dimensions. Each shape element can be a non-negative
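The folders added to Shape.cpp below rely on the fact that constants of both `index` and `!shape.size` are backed by `IntegerAttr`, so the attribute can simply be passed through. At the IR level the effect is roughly the following (a hand-written before/after sketch, mirroring the tests at the end of this patch):

```mlir
// Before canonicalization:
%ci = constant 123 : index
%cs = shape.index_to_size %ci

// After folding, the conversion disappears and the constant materializes
// directly in the shape dialect:
%cs = shape.const_size 123
```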
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -136,6 +136,34 @@
   return builder.getI64TensorAttr(resultShape);
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+                           ValueRange operands, DictionaryAttr attributes,
+                           RegionRange regions,
+                           SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto shapeType = ShapeType::get(context);
+  inferredReturnTypes.push_back(shapeType);
+  return success();
+}
+
+OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1])
+    return nullptr;
+  auto lhsShape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  SmallVector<int64_t, 6> resultShape;
+  resultShape.append(lhsShape.begin(), lhsShape.end());
+  resultShape.append(rhsShape.begin(), rhsShape.end());
+  Builder builder(getContext());
+  return builder.getI64TensorAttr(resultShape);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstShapeOp
 //===----------------------------------------------------------------------===//
@@ -177,7 +205,7 @@
   return success();
 }
 
-OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
+OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 
 LogicalResult
 ConstShapeOp::inferReturnTypes(MLIRContext *context,
@@ -201,6 +229,58 @@
   return success();
 }
 
+OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
+
+//===----------------------------------------------------------------------===//
+// IndexToSizeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only when the argument is constant.
+  Attribute arg = operands[0];
+  if (!arg)
+    return {};
+
+  // Constant values of both types, `shape.size` and `index`, are represented
+  // as `IntegerAttr`s, which makes constant folding simple.
+  return arg;
+}
+
+LogicalResult IndexToSizeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(SizeType::get(context));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// NumElementsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only when the argument is constant.
+  Attribute shape = operands[0];
+  if (!shape)
+    return {};
+
+  // The number of elements is the product of all extents.
+  APInt product(64, 1);
+  for (auto value : shape.cast<DenseIntElementsAttr>())
+    product *= value;
+  Builder builder(getContext());
+  return builder.getIndexAttr(product.getLimitedValue());
+}
+
+LogicalResult NumElementsOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(SizeType::get(context));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
@@ -222,6 +302,30 @@
   return builder.getI64TensorAttr(type.getShape());
 }
 
+//===----------------------------------------------------------------------===//
+// SizeToIndexOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only when the argument is constant.
+  Attribute arg = operands[0];
+  if (!arg)
+    return {};
+
+  // Constant values of both types, `shape.size` and `index`, are represented
+  // as `IntegerAttr`s, which makes constant folding simple.
+  return arg;
+}
+
+LogicalResult SizeToIndexOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(IndexType::get(context));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplitAtOp
 //===----------------------------------------------------------------------===//
@@ -258,34 +362,6 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// ConcatOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
-                           ValueRange operands, DictionaryAttr attributes,
-                           RegionRange regions,
-                           SmallVectorImpl<Type> &inferredReturnTypes) {
-  auto shapeType = ShapeType::get(context);
-  inferredReturnTypes.push_back(shapeType);
-  return success();
-}
-
-OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0] || !operands[1])
-    return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  SmallVector<int64_t, 6> resultShape;
-  resultShape.append(lhsShape.begin(), lhsShape.end());
-  resultShape.append(rhsShape.begin(), rhsShape.end());
-  Builder builder(getContext());
-  return builder.getI64TensorAttr(resultShape);
-}
-
 //===----------------------------------------------------------------------===//
 // ToExtentTensorOp
 //===----------------------------------------------------------------------===//
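`NumElementsOp::fold` multiplies all extents of a constant shape in 64-bit `APInt` arithmetic, so the element count is computed entirely at compile time. For instance, mirroring the `@num_elements` test below:

```mlir
%shape = shape.const_shape [4, 5, 6]
%n = shape.num_elements %shape
// Folds to %n = shape.const_size 120, since 4 * 5 * 6 = 120.
```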
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -86,3 +86,77 @@
   %0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex>
   return %0 : tensor<2xindex>
 }
+
+// -----
+// Cast constant size to index and fold it away.
+// CHECK-LABEL: func @const_size_to_index
+func @const_size_to_index() -> index {
+  // CHECK-NOT: shape.size_to_index
+  %cs = shape.const_size 123
+  // CHECK: constant 123 : index
+  %ci = shape.size_to_index %cs
+  return %ci : index
+}
+
+// -----
+// Cast constant index to size and fold it away.
+// CHECK-LABEL: func @const_index_to_size
+func @const_index_to_size() -> !shape.size {
+  // CHECK-NOT: shape.index_to_size
+  %ci = constant 123 : index
+  // CHECK: shape.const_size 123
+  %cs = shape.index_to_size %ci
+  return %cs : !shape.size
+}
+
+// -----
+// Cast constant index to size, then back, and fold it away.
+// CHECK-LABEL: func @const_index_to_size_to_index
+func @const_index_to_size_to_index() -> index {
+  // CHECK-NOT: shape.index_to_size
+  %ci0 = constant 123 : index
+  %cs0 = shape.index_to_size %ci0
+  // CHECK: constant 123 : index
+  %ci1 = shape.size_to_index %cs0
+  return %ci1 : index
+}
+
+// -----
+// No folding.
+// CHECK-LABEL: func @nonfoldable_size_to_index
+func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
+  // CHECK: %[[CI:.*]] = shape.size_to_index %[[CS:.*]]
+  %ci = shape.size_to_index %cs
+  return %ci : index
+}
+
+// -----
+// No folding.
+// CHECK-LABEL: func @nonfoldable_index_to_size
+func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
+  // CHECK: %[[CS:.*]] = shape.index_to_size %[[CI:.*]]
+  %cs = shape.index_to_size %ci
+  return %cs : !shape.size
+}
+
+// -----
+// Fold number of elements computation.
+// CHECK-LABEL: func @num_elements
+func @num_elements() -> !shape.size {
+  // CHECK-NOT: shape.const_shape
+  %shape = shape.const_shape [4, 5, 6]
+  // CHECK-NOT: shape.num_elements
+  %num_elements = shape.num_elements %shape
+  // CHECK: %[[NUM:.*]] = shape.const_size 120
+  // CHECK-NEXT: return %[[NUM]] : !shape.size
+  return %num_elements : !shape.size
+}
+
+// -----
+// No folding.
+// CHECK-LABEL: func @nonfoldable_num_elements
+func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
+  // CHECK-NOT: shape.const_{{.*}}
+  %num_elements = shape.num_elements %shape
+  return %num_elements : !shape.size
+}
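The CHECK lines above are FileCheck directives. In the MLIR tree, canonicalization tests like these are typically driven by a RUN line at the top of the file, along the lines of the following (the exact flags are an assumption, not shown in this excerpt of the patch):

```mlir
// RUN: mlir-opt -split-input-file -canonicalize %s | FileCheck %s
```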