diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -214,6 +214,25 @@
   let hasFolder = 1;
 }
 
+def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Converts a standard index to a shape size";
+  let description = [{
+    Converts a standard index to a `shape.size`.
+    This operation and its inverse, `size_to_index`, facilitate index conversion
+    between the standard and the shape dialect.
+    The behavior is undefined for negative indices.
+  }];
+
+  let arguments = (ins Index:$arg);
+  let results = (outs Shape_SizeType:$result);
+
+  let assemblyFormat = "attr-dict $arg";
+
+  let hasFolder = 1;
+}
+
 def Shape_JoinOp : Shape_Op<"join", []> {
   let summary = "Returns the least general shape.size of its operands";
   let description = [{
@@ -312,6 +331,25 @@
   let hasFolder = 1;
 }
 
+def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Converts a shape size to a standard index";
+  let description = [{
+    Converts a `shape.size` to a standard index.
+    This operation and its inverse, `index_to_size`, facilitate index conversion
+    between the standard and the shape dialect.
+    The behavior is undefined for unknown and invalid arguments.
+  }];
+
+  let arguments = (ins Shape_SizeType:$arg);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = "attr-dict $arg";
+
+  let hasFolder = 1;
+}
+
 def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
   let summary = "Returns the value to parent op";
 
@@ -523,7 +561,6 @@
   let assemblyFormat = "$inputs attr-dict";
 }
 
-
 // Canonicalization patterns.
 
 #endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -249,7 +249,7 @@
   return success();
 }
 
-OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
+OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 
 //===----------------------------------------------------------------------===//
 // ConstSizeOp
@@ -267,6 +267,26 @@
 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
 
 //===----------------------------------------------------------------------===//
+// IndexToSizeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
+  // Constant values of both types, `shape.size` and `index`, are represented as
+  // `IntegerAttr`s which makes constant folding simple.
+  if (Attribute arg = operands[0])
+    return arg;
+  return {};
+}
+
+LogicalResult IndexToSizeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(SizeType::get(context));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // FromExtentsOp
 //===----------------------------------------------------------------------===//
 
@@ -334,6 +354,26 @@
 }
 
 //===----------------------------------------------------------------------===//
+// SizeToIndexOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
+  // Constant values of both types, `shape.size` and `index`, are represented as
+  // `IntegerAttr`s which makes constant folding simple.
+  if (Attribute arg = operands[0])
+    return arg;
+  return {};
+}
+
+LogicalResult SizeToIndexOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(IndexType::get(context));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // SplitAtOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -108,6 +108,60 @@
 }
 
 // -----
+// Cast constant size to index and fold it away.
+// CHECK-LABEL: func @const_size_to_index
+func @const_size_to_index() -> index {
+  %cs = shape.const_size 123
+  // CHECK: %[[CI:.*]] = constant 123 : index
+  // CHECK-NEXT: return %[[CI]] : index
+  %ci = shape.size_to_index %cs
+  return %ci : index
+}
+
+// -----
+// Cast constant index to size and fold it away.
+// CHECK-LABEL: func @const_index_to_size
+func @const_index_to_size() -> !shape.size {
+  %ci = constant 123 : index
+  // CHECK: %[[CS:.*]] = shape.const_size 123
+  // CHECK-NEXT: return %[[CS]] : !shape.size
+  %cs = shape.index_to_size %ci
+  return %cs : !shape.size
+}
+
+// -----
+// Cast constant index to size, then back, and fold it away.
+// CHECK-LABEL: func @const_index_to_size_to_index
+func @const_index_to_size_to_index() -> index {
+  // CHECK-NOT: shape.{{index_to_size|size_to_index}}
+  %ci0 = constant 123 : index
+  %cs0 = shape.index_to_size %ci0
+  // CHECK: %[[CI:.*]] = constant 123 : index
+  // CHECK-NEXT: return %[[CI]] : index
+  %ci1 = shape.size_to_index %cs0
+  return %ci1 : index
+}
+
+// -----
+// No folding: the operand is not a constant.
+// CHECK-LABEL: func @nonfoldable_size_to_index
+func @nonfoldable_size_to_index(%size : !shape.size) -> index {
+  // CHECK: shape.size_to_index
+  %idx = shape.size_to_index %size
+  return %idx : index
+}
+
+// -----
+// No folding: the operand is not a constant.
+// CHECK-LABEL: func @nonfoldable_index_to_size
+func @nonfoldable_index_to_size(%idx : index) -> !shape.size {
+  // CHECK: shape.index_to_size
+  %size = shape.index_to_size %idx
+  return %size : !shape.size
+}
+
+// -----
+
 // Canonicalization of shape.get_extent
 
 // Basic folding.