diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -185,24 +185,26 @@
   let summary = "Gets the specified extent from a shape";
   let description = [{
     Gets the extent indexed by `dim` from `shape`.
-
-    If the shape is an error, it returns an error size.
+    If the shape is an error, its behavior is undefined.
   }];
 
   let arguments = (ins
     Shape_ShapeType:$shape,
-    Confined<I64Attr, [IntNonNegative]>:$dim
+    Shape_SizeType:$dim
   );
   let results = (outs Shape_SizeType:$extent);
   let assemblyFormat = "$shape `,` $dim attr-dict";
 
   let builders = [
-    // Builder that allows passing a simple integer instead of an IntegerAttr.
-    OpBuilder<
-      [{OpBuilder &builder, OperationState &result, Value shape, int64_t dim}],
-      [{build(builder, result, shape, builder.getI64IntegerAttr(dim));}]
-    >
+    // Builder that allows passing a constant dimension as a simple integer.
+    OpBuilder<"OpBuilder &builder, OperationState &result, Value shape, "
+              "int64_t dim">
   ];
 
+  let extraClassDeclaration = [{
+    /// Get the `dim` value as integer if it is constant.
+    Optional<int64_t> getConstantDim();
+  }];
+
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -388,15 +388,38 @@
 // GetExtentOp
 //===----------------------------------------------------------------------===//
 
+// Builds a `shape.get_extent` whose `dim` operand is materialized as a
+// `shape.const_size` op holding the given integer.
+void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
+                        int64_t dim) {
+  Value dimValue = builder.create<ConstSizeOp>(
+      result.location, builder.getType<SizeType>(), builder.getIndexAttr(dim));
+  build(builder, result, builder.getType<SizeType>(), shape, dimValue);
+}
+
+// Returns the `dim` operand as an integer if it is produced by a
+// `shape.const_size` op, and `llvm::None` otherwise. `dim` may be a block
+// argument with no defining op, hence the null-tolerant cast.
+Optional<int64_t> GetExtentOp::getConstantDim() {
+  auto constSizeOp = dyn_cast_or_null<ConstSizeOp>(dim().getDefiningOp());
+  if (!constSizeOp)
+    return llvm::None;
+  return constSizeOp.value().getLimitedValue();
+}
+
 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (!elements)
     return nullptr;
-  uint64_t dimToGet = dim().getLimitedValue();
-  // TODO: Constant fold this to some kind of constant error.
-  if (dimToGet >= (uint64_t)elements.getNumElements())
+  // Folding additionally requires a constant dimension.
+  Optional<int64_t> dimIndex = getConstantDim();
+  if (!dimIndex.hasValue())
+    return nullptr;
+  // Out-of-bounds dimensions are left unfolded.
+  uint64_t index = static_cast<uint64_t>(dimIndex.getValue());
+  if (index >= static_cast<uint64_t>(elements.getNumElements()))
     return nullptr;
-  return elements.getValue({dimToGet});
+  return elements.getValue({index});
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -183,33 +183,45 @@
 }
 
 // -----
-
-// Canonicalization of shape.get_extent
-
 // Basic folding.
 // CHECK-LABEL: func @basic
 func @basic() -> !shape.size {
   // CHECK: shape.const_size 2
   %0 = shape.const_shape [0, 1, 2]
-  %1 = shape.get_extent %0, 2
+  %c2 = shape.const_size 2
+  %1 = shape.get_extent %0, %c2
   return %1 : !shape.size
 }
 
+// -----
 // Should not fold.
 // CHECK-LABEL: func @out_of_bounds
 func @out_of_bounds() -> !shape.size {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
   %0 = shape.const_shape [0, 1, 2]
-  %1 = shape.get_extent %0, 3
+  %c3 = shape.const_size 3
+  %1 = shape.get_extent %0, %c3
   return %1 : !shape.size
 }
 
+// -----
 // Should not fold.
 // CHECK-LABEL: func @not_const
 func @not_const(%arg0: !shape.shape) -> !shape.size {
   // CHECK: shape.get_extent
-  %0 = shape.get_extent %arg0, 3
+  %c3 = shape.const_size 3
+  %0 = shape.get_extent %arg0, %c3
   return %0 : !shape.size
 }
+
+// -----
+// Should not fold (and must not crash) when the dimension is not a constant.
+// CHECK-LABEL: func @dim_not_const
+func @dim_not_const(%arg0: !shape.size) -> !shape.size {
+  // CHECK: shape.get_extent
+  %0 = shape.const_shape [0, 1, 2]
+  %1 = shape.get_extent %0, %arg0
+  return %1 : !shape.size
+}
 