diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -131,6 +131,7 @@ let results = (outs Shape_SizeType:$result); let assemblyFormat = "attr-dict $value"; + let hasFolder = 1; } def Shape_FromExtentsOp : Shape_Op<"from_extents", [ @@ -190,6 +191,37 @@ let hasFolder = 1; } +def Shape_GetExtentOp : Shape_Op<"get_extent", + [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> { + let summary = "Gets the specified extent from a shape"; + let description = [{ + Gets the extent indexed by `dim` from `shape`; `dim` must be non-negative. + + If the shape is an error, it returns an error size. + }]; + let arguments = (ins + Shape_ShapeType:$shape, + Confined<I64Attr, [IntNonNegative]>:$dim + ); + let results = (outs Shape_SizeType:$extent); + let assemblyFormat = "$shape `,` $dim attr-dict"; + + let builders = [ + // Builder that allows passing a simple integer instead of an IntegerAttr. + OpBuilder< + [{ + OpBuilder &builder, OperationState &result, + Value shape, int64_t dim + }], + [{ + build(builder, result, shape, builder.getI64IntegerAttr(dim)); + }] + > + ]; + + let hasFolder = 1; +} + def Shape_JoinOp : Shape_Op<"join", []> { let summary = "Returns the least general shape.size of its operands"; let description = [{ diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -245,6 +245,8 @@ return success(); } +OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } + //===----------------------------------------------------------------------===// // FromExtentsOp //===----------------------------------------------------------------------===// @@ -268,6 +270,37 @@ } //===----------------------------------------------------------------------===// +// GetExtentOp +//===----------------------------------------------------------------------===// +
+LogicalResult +GetExtentOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl<Type> &inferredReturnTypes) { + inferredReturnTypes.push_back(SizeType::get(context)); // Always a size. + return success(); +} + +OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { + auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); + if (!elements) // The shape operand is not a constant. + return nullptr; + uint64_t dimToGet = dim().getLimitedValue(); + // TODO: Fold an out-of-bounds `dim` to some kind of constant error. + if (dimToGet >= static_cast<uint64_t>(elements.getNumElements())) + return nullptr; + // getValue yields an IntegerAttr of the shape's i64 element type, but the + // folded shape.size value must be an IndexType attribute, so the extent is + // re-wrapped with the index type below. + // TODO: Make ConstShapeOp hold a tensor of index instead of i64. + Builder builder(getContext()); + return builder.getIntegerAttr( + builder.getIndexType(), + elements.getValue<IntegerAttr>({dimToGet}).getInt()); +} + +//===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -106,3 +106,33 @@ %ret = shape.from_extents %e0, %arg0 return %ret : !shape.shape } + +// ----- +// Canonicalization of shape.get_extent + +// Basic folding: extent 2 of the constant shape [0, 1, 2] is 2. +// CHECK-LABEL: func @basic +func @basic() -> !shape.size { + // CHECK: shape.const_size 2 + %0 = shape.const_shape [0, 1, 2] + %1 = shape.get_extent %0, 2 + return %1 : !shape.size +} + +// Should not fold: dim 3 is out of bounds for a rank-3 shape. +// CHECK-LABEL: func @out_of_bounds +func @out_of_bounds() -> !shape.size { + // CHECK: shape.const_shape + // CHECK: shape.get_extent + %0 = shape.const_shape [0, 1, 2] + %1 = shape.get_extent %0, 3 + return %1 : !shape.size +} + +// Should not fold: the shape operand is not a constant.
+// CHECK-LABEL: func @not_const +func @not_const(%shape: !shape.shape) -> !shape.size { + // CHECK: shape.get_extent + %extent = shape.get_extent %shape, 3 // %shape is a block argument. + return %extent : !shape.size +}