diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2402,6 +2402,43 @@
   }
 };
 
+/// Lowers the standard dialect's `rank` op (RankOp) to the LLVM dialect.
+///
+/// - Unranked memref operand: the rank is only known at runtime, so it is
+///   read from the converted operand through `UnrankedMemRefDescriptor::rank`.
+/// - Ranked memref operand: the rank is a static property of the type, so the
+///   op is replaced by an index constant holding that rank.
+/// - Any other operand type (e.g. a tensor): the pattern returns failure() and
+///   leaves the op untouched.
+struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
+  using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    // Dispatch on the original (pre-conversion) operand type; the values in
+    // `operands` are already LLVM-dialect values and no longer carry it.
+    Type operandType = cast<RankOp>(op).memrefOrTensor().getType();
+    if (operandType.isa<UnrankedMemRefType>()) {
+      // NOTE(review): this patch's tests expect the unranked descriptor's rank
+      // field to be i64 even when `index` lowers to i32 (CHECK32); confirm the
+      // replacement value matches the converted result type in that mode.
+      UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
+      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
+      return success();
+    }
+    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
+      // Ranked memref: emit the statically known rank as an index constant.
+      rewriter.replaceOp(
+          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
+      return success();
+    }
+    // Not a memref operand (e.g. a tensor): this pattern does not apply.
+    return failure();
+  }
+};
+
 // Common base for load and store operations on MemRefs. Restricts the match
 // to supported MemRef types. Provides functionality to emit code accessing a
 // specific element of the underlying data buffer.
@@ -3272,6 +3294,7 @@
       DimOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
+      RankOpLowering,
       StoreOpLowering,
       SubViewOpLowering,
       ViewOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1291,3 +1291,28 @@
 func @memref_index(%arg0: memref<32xindex>) -> memref<32xindex> {
   return %arg0 : memref<32xindex>
 }
+
+// -----
+
+// Unranked memref: the rank is extracted from field 0 of the descriptor.
+// CHECK-LABEL: func @rank_of_unranked
+// CHECK32-LABEL: func @rank_of_unranked
+func @rank_of_unranked(%unranked: memref<*xi32>) {
+  %rank = rank %unranked : memref<*xi32>
+  return
+}
+// CHECK-NEXT: llvm.mlir.undef
+// CHECK-NEXT: llvm.insertvalue
+// CHECK-NEXT: llvm.insertvalue
+// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i8* }">
+// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i8* }">
+
+// Ranked memref: the rank is replaced by a constant (here rank 1).
+// CHECK-LABEL: func @rank_of_ranked
+// CHECK32-LABEL: func @rank_of_ranked
+func @rank_of_ranked(%ranked: memref<?xi32>) {
+  %rank = rank %ranked : memref<?xi32>
+  return
+}
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK32: llvm.mlir.constant(1 : index) : !llvm.i32