[MLIR][Shape] Canonicalize `get_extent` on `shape_of`

Rewrite `shape.get_extent(shape.shape_of(%arg), %idx)` into
`shape.index_to_size(dim(%arg, shape.size_to_index(%idx)))` so that
`get_extent` no longer relies on `shape_of`.

To support this:
- `GetExtentOp` gets a canonicalizer (`hasCanonicalizer = 1`) that registers
  the DRR pattern `GetExtentShapeOfCanonicalization`.
- `DimOp` gets a builder taking the index as a `Value`; the existing
  `int64_t` builder now materializes a constant index and forwards to it.
- `DimOp::fold` uses `dyn_cast_or_null`, because the index attribute is null
  when the index operand is not a constant (as produced by the new pattern).
- A FileCheck test covers the new canonicalization.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -212,6 +212,7 @@
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1408,7 +1408,9 @@
 
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "Value memrefOrTensor, int64_t index">
+              "Value memrefOrTensor, int64_t index">,
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "Value memrefOrTensor, Value index">,
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Shape/IR/Shape.h"
 
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -405,6 +406,11 @@
 // GetExtentOp
 //===----------------------------------------------------------------------===//
 
+void GetExtentOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<GetExtentShapeOfCanonicalization>(context);
+}
+
 Optional<int64_t> GetExtentOp::getConstantDim() {
   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
     return constSizeOp.value().getLimitedValue();
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -1,4 +1,5 @@
 include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 
 def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
 
@@ -16,3 +17,9 @@
 def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
                       (Shape_ConstWitnessOp ConstBoolAttrTrue),
                       [(AllInputShapesEq $shapes)]>;
+
+// Rewrite `get_extent(shape_of(%arg), %idx)` into a `dim` on `%arg` so the
+// extent no longer relies on `shape_of`. The extent index and the result are
+// converted between `!shape.size` and `index`.
+def GetExtentShapeOfCanonicalization : Pat<
+  (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
+  (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx)))>;
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1268,8 +1268,13 @@
                   Value memrefOrTensor, int64_t index) {
   auto loc = result.location;
   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
+  build(builder, result, memrefOrTensor, indexValue);
+}
+
+void DimOp::build(OpBuilder &builder, OperationState &result,
+                  Value memrefOrTensor, Value index) {
   auto indexTy = builder.getIndexType();
-  build(builder, result, indexTy, memrefOrTensor, indexValue);
+  build(builder, result, indexTy, memrefOrTensor, index);
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
@@ -1303,7 +1308,9 @@
 }
 
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
-  auto index = operands[1].dyn_cast<IntegerAttr>();
+  // `operands[1]` is null when the index is not a constant (e.g. when it is
+  // produced by `shape.size_to_index`), so the cast must tolerate null.
+  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
 
   // All forms of folding require a known index.
   if (!index)
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -442,3 +442,18 @@
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
+
+// -----
+
+// Canonicalize `get_extent` not to rely on `shape_of`.
+// CHECK-LABEL: @get_extent
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>, %[[IDX:.*]]: !shape.size) -> !shape.size
+func @get_extent(%arg : tensor<2x?xf32>, %idx : !shape.size) -> !shape.size {
+  // CHECK-NEXT: %[[IDX_AS_INDEX:.*]] = shape.size_to_index %[[IDX]]
+  // CHECK-NEXT: %[[RESULT_AS_INDEX:.*]] = dim %[[ARG]], %[[IDX_AS_INDEX]] : tensor<2x?xf32>
+  // CHECK-NEXT: %[[RESULT:.*]] = shape.index_to_size %[[RESULT_AS_INDEX]]
+  // CHECK-NEXT: return %[[RESULT]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<2x?xf32>
+  %result = shape.get_extent %shape, %idx
+  return %result : !shape.size
+}