diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1408,7 +1408,9 @@
 
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "Value memrefOrTensor, int64_t index">
+              "Value memrefOrTensor, int64_t index">,
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "Value memrefOrTensor, Value index">
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/ShapeToStandard/LoweringPatterns.td b/mlir/lib/Conversion/ShapeToStandard/LoweringPatterns.td
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/LoweringPatterns.td
@@ -0,0 +1,27 @@
+//===- LoweringPatterns.td - Shape-to-Standard patterns ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarative rewrite patterns (DRR) used by the Shape-to-Standard conversion.
+// ShapeToStandard.cpp includes the generated `LoweringPatterns.inc`.
+// NOTE(review): this needs a matching `mlir_tablegen(LoweringPatterns.inc
+// -gen-rewriters)` rule in this directory's CMakeLists.txt; confirm it exists.
+//
+//===----------------------------------------------------------------------===//
+
+include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
+
+// Rewrite `shape.get_extent(shape.shape_of(%arg), %idx)` to
+// `shape.index_to_size(dim %arg, shape.size_to_index(%idx))`, reading the
+// extent directly from `%arg`. `%arg` is restricted to tensors and memrefs
+// because `dim` expects a memref or tensor operand (`memrefOrTensor`);
+// `shape_of` applied to any other kind of value is left to other lowerings.
+def GetExtentShapeOfConversion : Pat<
+  (Shape_GetExtentOp (Shape_ShapeOfOp AnyTypeOf<[AnyTensor, AnyMemRef]>:$arg),
+                     $idx),
+  (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx)))>;
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -19,6 +19,8 @@
 
 namespace {
 
+#include "LoweringPatterns.inc"
+
 /// Conversion patterns.
 template <typename SrcOpTy, typename DstOpTy>
 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
@@ -135,6 +137,9 @@
       return typeConverter.isSignatureLegal(op.getType());
     });
 
+    // TODO: Remove this when lowering for `shape_of` exists.
+    target.addLegalOp<ShapeOfOp>();
+
     // Setup conversion patterns.
     OwningRewritePatternList patterns;
     populateShapeToStandardConversionPatterns(patterns, &ctx);
@@ -157,6 +162,7 @@
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
       FromExtentTensorOpConversion,
+      GetExtentShapeOfConversion,
       IndexToSizeOpConversion,
       SizeToIndexOpConversion,
       ToExtentTensorOpConversion>(ctx);
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1273,8 +1273,13 @@
                   Value memrefOrTensor, int64_t index) {
   auto loc = result.location;
   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
+  build(builder, result, memrefOrTensor, indexValue);
+}
+
+void DimOp::build(OpBuilder &builder, OperationState &result,
+                  Value memrefOrTensor, Value index) {
   auto indexTy = builder.getIndexType();
-  build(builder, result, indexTy, memrefOrTensor, indexValue);
+  build(builder, result, indexTy, memrefOrTensor, index);
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
@@ -1308,7 +1313,8 @@
 }
 
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
-  auto index = operands[1].dyn_cast<IntegerAttr>();
+  // `operands[1]` is null when the index operand is not a constant.
+  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
 
   // All forms of folding require a known index.
   if (!index)
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -86,3 +86,16 @@
 }
 // CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: return %[[C1]] : index
+
+// -----
+
+// Convert `get_extent` when it relies on the outcome of a `shape_of` operation.
+// CHECK-LABEL: @get_extent
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
+func @get_extent(%arg : tensor<2x3xf32>, %idx : !shape.size) -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
+  // CHECK: return %[[RESULT]] : index
+  %shape = shape.shape_of %arg : tensor<2x3xf32>
+  %result = shape.get_extent %shape, %idx
+  return %result : !shape.size
+}