diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -185,24 +185,26 @@ let summary = "Gets the specified extent from a shape"; let description = [{ Gets the extent indexed by `dim` from `shape`. - If the shape is an error, it returns an error size. }]; let arguments = (ins Shape_ShapeType:$shape, - Confined<I64Attr, [IntNonNegative]>:$dim + Shape_SizeType:$dim ); let results = (outs Shape_SizeType:$extent); let assemblyFormat = "$shape `,` $dim attr-dict"; let builders = [ - // Builder that allows passing a simple integer instead of an IntegerAttr. - OpBuilder< - [{OpBuilder &builder, OperationState &result, Value shape, int64_t dim}], - [{build(builder, result, shape, builder.getI64IntegerAttr(dim));}] - > + // Builder that allows passing a constant dimension as a simple integer. + OpBuilder<"OpBuilder &builder, OperationState &result, Value shape, " + "int64_t dim"> ]; + let extraClassDeclaration = [{ + /// Get the `dim` value as integer if it is produced by a ConstSizeOp. + Optional<int64_t> getConstantDim(); + }]; + let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -337,15 +337,32 @@ // GetExtentOp //===----------------------------------------------------------------------===// +Optional<int64_t> GetExtentOp::getConstantDim() { + if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) { + return constSizeOp.value().getLimitedValue(); + } + return llvm::None; +} + OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); if (!elements) return nullptr; - uint64_t dimToGet = dim().getLimitedValue(); - // TODO: Constant fold this to some kind of constant error. 
- if (dimToGet >= (uint64_t)elements.getNumElements()) + Optional<int64_t> dim = getConstantDim(); + if (!dim.hasValue()) return nullptr; - return elements.getValue({dimToGet}); + // Reject negative dims as well: the uint64_t cast below would wrap them. + if (dim.getValue() < 0 || dim.getValue() >= elements.getNumElements()) + return nullptr; + return elements.getValue({(uint64_t)dim.getValue()}); +} + +void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, + int64_t dim) { + auto loc = result.location; + auto dimAttr = builder.getIndexAttr(dim); + Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr); + build(builder, result, shape, dimValue); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1,6 +1,5 @@ // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s --dump-input=fail -// ----- // CHECK-LABEL: func @f func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape { // CHECK: shape.const_shape [2, 3, 4] @@ -9,6 +8,7 @@ } // ----- + // Basic case. // CHECK-LABEL: func @f func @f() -> (!shape.shape, !shape.shape) { @@ -22,6 +22,7 @@ } // ----- + // Negative split point. // CHECK-LABEL: func @f func @f() -> (!shape.shape, !shape.shape) { @@ -34,6 +35,7 @@ } // ----- + // Out of range split point. No folding. // CHECK-LABEL: func @f func @f() -> (!shape.shape, !shape.shape) { @@ -45,6 +47,7 @@ } // ----- + // Basic case. // CHECK-LABEL: func @f func @f() -> !shape.shape { @@ -56,6 +59,7 @@ } // ----- + // Incompatible shapes. No folding. // CHECK-LABEL: func @f func @f() -> !shape.shape { @@ -67,6 +71,7 @@ } // ----- + // Basic case. // CHECK-LABEL: func @f func @f() -> !shape.shape { @@ -78,6 +83,7 @@ } // ----- + // Basic case. // CHECK-LABEL: func @f func @f() -> tensor<2xindex> { @@ -88,6 +94,7 @@ } // ----- + // Basic case. 
// CHECK-LABEL: func @f() func @f() -> !shape.shape { @@ -99,6 +106,8 @@ return %ret : !shape.shape } +// ----- + // CHECK-LABEL: func @no_fold func @no_fold(%arg0: index) -> !shape.shape { // CHECK-NOT: shape.const_shape @@ -108,6 +117,7 @@ } // ----- + // Cast constant size to index and fold it away. // CHECK-LABEL: func @const_size_to_index func @const_size_to_index() -> index { @@ -119,6 +129,7 @@ } // ----- + // Cast constant index to size and fold it away. // CHECK-LABEL: func @const_index_to_size func @const_index_to_size() -> !shape.size { @@ -130,6 +141,7 @@ } // ----- + // Cast constant index to size, then back, and fold it away. // CHECK-LABEL: func @const_index_to_size_to_index func @const_index_to_size_to_index() -> index { @@ -143,6 +155,7 @@ } // ----- + // No folding. // CHECK-LABEL: func @nonfoldable_size_to_index func @nonfoldable_size_to_index(%cs : !shape.size) -> index { @@ -152,6 +165,7 @@ } // ----- + // No folding. // CHECK-LABEL: func @nonfoldable_index_to_size func @nonfoldable_index_to_size(%ci : index) -> !shape.size { @@ -161,6 +175,7 @@ } // ----- + // Fold number of elements computation. // CHECK-LABEL: func @num_elements func @num_elements() -> !shape.size { @@ -174,6 +189,7 @@ } // ----- + // No folding. // CHECK-LABEL: func @nonfoldable_num_elements func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size { @@ -184,36 +200,42 @@ // ----- -// Canonicalization of shape.get_extent - // Basic folding. // CHECK-LABEL: func @basic func @basic() -> !shape.size { // CHECK: shape.const_size 2 %0 = shape.const_shape [0, 1, 2] - %1 = shape.get_extent %0, 2 + %c2 = shape.const_size 2 + %1 = shape.get_extent %0, %c2 return %1 : !shape.size } +// ----- + // Should not fold. 
// CHECK-LABEL: func @out_of_bounds func @out_of_bounds() -> !shape.size { // CHECK: shape.const_shape // CHECK: shape.get_extent %0 = shape.const_shape [0, 1, 2] - %1 = shape.get_extent %0, 3 + %c3 = shape.const_size 3 + %1 = shape.get_extent %0, %c3 return %1 : !shape.size } +// ----- + // Should not fold. // CHECK-LABEL: func @not_const func @not_const(%arg0: !shape.shape) -> !shape.size { // CHECK: shape.get_extent - %0 = shape.get_extent %arg0, 3 + %c3 = shape.const_size 3 + %0 = shape.get_extent %arg0, %c3 return %0 : !shape.size } // ----- + // assuming_all with known passing witnesses can be folded // CHECK-LABEL: func @f func @f() { @@ -229,6 +251,7 @@ } // ----- + // assuming_all should not be removed if not all witnesses are statically passing. // // Additionally check that the attribute is moved to the end as this op is @@ -247,6 +270,7 @@ } // ----- + // any can be replaced with a constant input if it has one. // CHECK-LABEL: func @f func @f(%arg0 : !shape.shape) -> !shape.shape { @@ -259,6 +283,7 @@ // ----- + // Folding of any with partially constant operands is not yet implemented. // CHECK-LABEL: func @f func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape { @@ -269,6 +294,7 @@ } // ----- + // assuming with a known passing witness can be removed // CHECK-LABEL: func @f func @f() { @@ -285,6 +311,7 @@ } // ----- + // assuming without a known passing passing witness cannot be removed // CHECK-LABEL: func @f func @f() {