diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -328,6 +328,41 @@
   let hasFolder = 1;
 }
 
+def Shape_DimOp : Shape_Op<"dim",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Gets the specified extent from the shape of a shaped input";
+  let description = [{
+    Gets the extent indexed by `dim` from the shape of the `value` operand. If
+    the dim is error or out-of-bound then it returns an invalid size if the
+    return type carries error information else the behavior is undefined.
+
+    This is a convenience op that performs the equivalent of getting the extent
+    of a shape (e.g., `dim(x, i) == get_extent(shape_of(x), i)`).
+  }];
+  let arguments = (ins AnyShaped:$value,
+                       Shape_SizeOrIndexType:$dim);
+  let results = (outs Shape_SizeOrIndexType:$extent);
+  let assemblyFormat = "$value `,` $dim attr-dict `:` type($value) `,` type($dim) `->` "
+                       "type($extent)";
+
+  let builders = [
+    // Builder that allows passing a constant dimension as a simple integer.
+    OpBuilder<(ins "Value":$value, "int64_t":$dim)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Get the `dim` value as integer if it is constant.
+    Optional<int64_t> getConstantDim();
+
+    /// Returns whether two result types are compatible for this op; method used
+    /// by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
+
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 def Shape_GetExtentOp : Shape_Op<"get_extent",
     [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Gets the specified extent from a shape or extent tensor";
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -322,6 +322,28 @@
   return success();
 }
 
+namespace {
+class DimOpConverter : public OpConversionPattern<DimOp> {
+  using OpConversionPattern<DimOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(DimOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  // Lower dim(X, i) to get_extent(shape_of(X), i) and rely on further
+  // lowerings. This can be further optimized if needed to avoid intermediate
+  // steps.
+  auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
+  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
+                                                  op.getDim());
+  return success();
+}
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -693,6 +715,7 @@
       BroadcastOpConverter,
       ConstShapeOpConverter,
       ConstSizeOpConversion,
+      DimOpConverter,
       IsBroadcastableOpConverter,
       GetExtentOpConverter,
       RankOpConverter,
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1064,6 +1064,67 @@
   return operands[0];
 }
 
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the `dim` operand as an integer when it is defined by a ConstSizeOp
+/// or by an arith::ConstantOp holding an IntegerAttr; llvm::None otherwise.
+Optional<int64_t> DimOp::getConstantDim() {
+  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
+    return constSizeOp.getValue().getLimitedValue();
+  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
+    return constantOp.getValue().cast<IntegerAttr>().getInt();
+  return llvm::None;
+}
+
+/// Folds to a constant index attribute when the operand is ranked, `dim` is a
+/// constant within [0, rank), and the selected dimension is not dynamic.
+OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+  Type valType = getValue().getType();
+  auto valShapedType = valType.dyn_cast<ShapedType>();
+  if (!valShapedType || !valShapedType.hasRank())
+    return nullptr;
+  Optional<int64_t> dim = getConstantDim();
+  if (!dim.has_value())
+    return nullptr;
+  // The constant may be negative; never query the shape at an invalid index.
+  if (*dim < 0 || *dim >= valShapedType.getRank())
+    return nullptr;
+  auto extent = valShapedType.getDimSize(*dim);
+  if (ShapedType::isDynamic(extent))
+    return nullptr;
+  return IntegerAttr::get(IndexType::get(getContext()), extent);
+}
+
+/// Infers the single result type to be the type of the `dim` operand.
+LogicalResult mlir::shape::DimOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  DimOpAdaptor dimOp(operands);
+  inferredReturnTypes.assign({dimOp.getDim().getType()});
+  return success();
+}
+
+bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l,
+                                                 TypeRange r) {
+  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
+
+/// Rejects a constant `dim` that is negative or not below the operand's rank.
+/// Unranked operands and non-constant `dim` values are accepted unchanged.
+LogicalResult mlir::shape::DimOp::verify() {
+  auto st = getValue().getType().cast<ShapedType>();
+  if (!st.hasRank())
+    return success();
+  if (auto dim = getConstantDim()) {
+    if (*dim < 0 || *dim >= st.getRank())
+      return emitOpError("index is out of range");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DivOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -60,6 +60,18 @@
 
 // -----
 
+// Express `shape.dim` as `tensor.dim` when valid.
+// CHECK-LABEL: @dim
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
+func.func @dim(%arg : tensor<2x3xf32>, %idx : index) -> index {
+  // CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
+  // CHECK: return %[[RESULT]] : index
+  %result = shape.dim %arg, %idx : tensor<2x3xf32>, index -> index
+  return %result : index
+}
+
+// -----
+
 // Express `get_extent` as `tensor.dim` when it relies directly on the outcome of a
 // `shape_of` operation.
 // CHECK-LABEL: @get_extent_shape_of