diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -131,6 +131,8 @@
   let results = (outs Shape_SizeType:$result);
 
   let assemblyFormat = "attr-dict $value";
+
+  let hasFolder = 1;
 }
 
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [
@@ -190,6 +192,24 @@
   let hasFolder = 1;
 }
 
+def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Converts a standard index to a shape size";
+  let description = [{
+    Converts a standard index to a `shape.size`.
+    This operation and its inverse, `size_to_index`, facilitate index conversion
+    between the standard and the shape dialect.
+  }];
+
+  let arguments = (ins Index:$arg);
+  let results = (outs Shape_SizeType:$result);
+
+  let assemblyFormat = "attr-dict $arg";
+
+  let hasFolder = 1;
+}
+
 def Shape_JoinOp : Shape_Op<"join", []> {
   let summary = "Returns the least general shape.size of its operands";
   let description = [{
@@ -288,6 +308,24 @@
   let hasFolder = 1;
 }
 
+def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Converts a shape size to a standard index";
+  let description = [{
+    Converts a `shape.size` to a standard index.
+    This operation and its inverse, `index_to_size`, facilitate index conversion
+    between the standard and the shape dialect.
+  }];
+
+  let arguments = (ins Shape_SizeType:$arg);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = "attr-dict $arg";
+
+  let hasFolder = 1;
+}
+
 def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
   let summary = "Returns the value to parent op";
 
@@ -499,7 +537,6 @@
   let assemblyFormat = "$inputs attr-dict";
 }
 
-
 // Canonicalization patterns.
 
 #endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -249,7 +249,7 @@
   return success();
 }
 
-OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
+OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 
 LogicalResult
 ConstShapeOp::inferReturnTypes(MLIRContext *context,
@@ -273,6 +273,29 @@
   return success();
 }
 
+OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
+
+//===----------------------------------------------------------------------===//
+// IndexToSizeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
+  // Forward the operand's constant if it is known; otherwise do not fold.
+  // Constant values of both types, `shape.size` and `index`, are represented as
+  // `IntegerAttr`s which makes constant folding simple.
+  if (Attribute arg = operands[0])
+    return arg;
+  return {};
+}
+
+LogicalResult IndexToSizeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(SizeType::get(context));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // FromExtentsOp
 //===----------------------------------------------------------------------===//
@@ -316,6 +339,27 @@
   return builder.getI64TensorAttr(type.getShape());
 }
 
+//===----------------------------------------------------------------------===//
+// SizeToIndexOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
+  // Forward the operand's constant if it is known; otherwise do not fold.
+  // Constant values of both types, `shape.size` and `index`, are represented as
+  // `IntegerAttr`s which makes constant folding simple.
+  if (Attribute arg = operands[0])
+    return arg;
+  return {};
+}
+
+LogicalResult SizeToIndexOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(IndexType::get(context));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplitAtOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -88,6 +88,7 @@
 }
 
 // -----
+
 // Basic case.
 // CHECK-LABEL: func @f()
 func @f() -> !shape.shape {
@@ -106,3 +107,56 @@
   %ret = shape.from_extents %e0, %arg0
   return %ret : !shape.shape
 }
+
+// -----
+// Cast constant size to index and fold it away.
+// CHECK-LABEL: func @const_size_to_index
+func @const_size_to_index() -> index {
+  // CHECK-NOT: shape.size_to_index
+  %cs = shape.const_size 123
+  // CHECK: constant 123 : index
+  %ci = shape.size_to_index %cs
+  return %ci : index
+}
+
+// -----
+// Cast constant index to size and fold it away.
+// CHECK-LABEL: func @const_index_to_size
+func @const_index_to_size() -> !shape.size {
+  // CHECK-NOT: shape.index_to_size
+  %ci = constant 123 : index
+  // CHECK: shape.const_size 123
+  %cs = shape.index_to_size %ci
+  return %cs : !shape.size
+}
+
+// -----
+// Cast constant index to size, then back, and fold it away.
+// CHECK-LABEL: func @const_index_to_size_to_index
+func @const_index_to_size_to_index() -> index {
+  // CHECK-NOT: shape.{{index_to_size|size_to_index}}
+  %ci0 = constant 123 : index
+  %cs0 = shape.index_to_size %ci0
+  // CHECK: %[[CI:.*]] = constant 123 : index
+  // CHECK-NEXT: return %[[CI]] : index
+  %ci1 = shape.size_to_index %cs0
+  return %ci1 : index
+}
+
+// -----
+// No folding.
+// CHECK-LABEL: func @nonfoldable_size_to_index
+func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
+  // CHECK: shape.size_to_index
+  %ci = shape.size_to_index %cs
+  return %ci : index
+}
+
+// -----
+// No folding.
+// CHECK-LABEL: func @nonfoldable_index_to_size
+func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
+  // CHECK: shape.index_to_size
+  %cs = shape.index_to_size %ci
+  return %cs : !shape.size
+}