diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2375,30 +2375,85 @@
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto dimOp = cast<DimOp>(op);
-    auto loc = op->getLoc();
+    Type operandType = dimOp.memrefOrTensor().getType();
+    if (operandType.isa<UnrankedMemRefType>()) {
+      rewriter.replaceOp(op, {extractSizeOfUnrankedMemRef(operandType, dimOp,
+                                                          operands, rewriter)});
+      return success();
+    }
+    if (operandType.isa<MemRefType>()) {
+      rewriter.replaceOp(op, {extractSizeOfRankedMemRef(operandType, dimOp,
+                                                        operands, rewriter)});
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  Value extractSizeOfUnrankedMemRef(Type operandType, DimOp dimOp,
+                                    ArrayRef<Value> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+    Location loc = dimOp.getLoc();
     DimOp::Adaptor transformed(operands);
+    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
+    auto scalarMemRefType =
+        MemRefType::get({}, unrankedMemRefType.getElementType());
+    unsigned addressSpace = unrankedMemRefType.getMemorySpace();
+
+    // Extract pointer to the underlying ranked descriptor and bitcast it to a
+    // memref descriptor pointer to minimize the number of GEP
+    // operations.
+    UnrankedMemRefDescriptor unrankedDesc(transformed.memrefOrTensor());
+    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
+    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        typeConverter.convertType(scalarMemRefType)
+            .cast<LLVM::LLVMType>()
+            .getPointerTo(addressSpace),
+        underlyingRankedDesc);
+
+    // Get pointer to the offset field of the memref descriptor.
+    Type indexPtrTy = typeConverter.getIndexType().getPointerTo(addressSpace);
+    Value two = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter.convertType(rewriter.getI32Type()),
+        rewriter.getI32IntegerAttr(2));
+    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
+        loc, indexPtrTy, scalarMemRefDescPtr,
+        ValueRange({createIndexConstant(rewriter, loc, 0), two}));
+
+    // The size value that we have to extract can be obtained using GEPOp with
+    // `dimOp.index() + 1` as the index argument.
+    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
+        loc, createIndexConstant(rewriter, loc, 1), transformed.index());
+    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
+                                                 ValueRange({idxPlusOne}));
+    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
+  }
+
+  Value extractSizeOfRankedMemRef(Type operandType, DimOp dimOp,
+                                  ArrayRef<Value> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+    Location loc = dimOp.getLoc();
+    DimOp::Adaptor transformed(operands);
 
     // Take advantage if index is constant.
-    MemRefType memRefType = dimOp.memrefOrTensor().getType().cast<MemRefType>();
+    MemRefType memRefType = operandType.cast<MemRefType>();
     if (Optional<int64_t> index = dimOp.getConstantIndex()) {
       int64_t i = index.getValue();
       if (memRefType.isDynamicDim(i)) {
         // Extract dynamic size from the memref descriptor.
         MemRefDescriptor descriptor(transformed.memrefOrTensor());
-        rewriter.replaceOp(op, {descriptor.size(rewriter, loc, i)});
-      } else {
-        // Use constant for static size.
-        int64_t dimSize = memRefType.getDimSize(i);
-        rewriter.replaceOp(op, createIndexConstant(rewriter, loc, dimSize));
+        return descriptor.size(rewriter, loc, i);
       }
-      return success();
+      // Use constant for static size.
+      int64_t dimSize = memRefType.getDimSize(i);
+      return createIndexConstant(rewriter, loc, dimSize);
     }
-
     Value index = dimOp.index();
     int64_t rank = memRefType.getRank();
     MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
-    rewriter.replaceOp(op, {memrefDescriptor.size(rewriter, loc, index, rank)});
-    return success();
+    return memrefDescriptor.size(rewriter, loc, index, rank);
   }
 };
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1326,7 +1326,7 @@
   } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
     if (index.getValue() >= memrefType.getRank())
       return op.emitOpError("index is out of range");
-  } else if (type.isa<UnrankedTensorType>()) {
+  } else if (type.isa<UnrankedTensorType>() || type.isa<UnrankedMemRefType>()) {
     // Assume index to be in range.
   } else {
     llvm_unreachable("expected operand with tensor or memref type");
@@ -1342,8 +1342,11 @@
   if (!index)
     return {};
 
-  // Fold if the shape extent along the given index is known.
   auto argTy = memrefOrTensor().getType();
+  if (argTy.isa<UnrankedTensorType>() || argTy.isa<UnrankedMemRefType>()) {
+    return {};
+  }
+  // Fold if the shape extent along the given index is known.
   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
     if (!shapedTy.isDynamicDim(index.getInt())) {
       Builder builder(getContext());
@@ -1357,7 +1360,7 @@
     return {};
 
   // The size at the given index is now known to be a dynamic size of a memref.
-  auto memref = memrefOrTensor().getDefiningOp();
+  auto *memref = memrefOrTensor().getDefiningOp();
   unsigned unsignedIndex = index.getValue().getZExtValue();
   if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
     return *(alloc.getDynamicSizes().begin() +
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1291,3 +1291,42 @@
 func @memref_index(%arg0: memref<32xindex>) -> memref<32xindex> {
   return %arg0 : memref<32xindex>
 }
+
+// -----
+
+// CHECK-LABEL: func @dim_of_unranked
+// CHECK32-LABEL: func @dim_of_unranked
+func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
+  %c0 = constant 0 : index
+  %dim = dim %unranked, %c0 : memref<*xi32>
+  return %dim : index
+}
+// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+// CHECK-NEXT: llvm.insertvalue
+// CHECK-NEXT: %[[UNRANKED_DESC:.*]] = llvm.insertvalue
+// CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+
+// CHECK-NEXT: %[[RANKED_DESC:.*]] = llvm.extractvalue %[[UNRANKED_DESC]][1]
+// CHECK-SAME:   : !llvm.struct<(i64, ptr<i8>)>
+
+// CHECK-NEXT: %[[ZERO_D_DESC:.*]] = llvm.bitcast %[[RANKED_DESC]]
+// CHECK-SAME:   : !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<i32>, ptr<i32>, i64)>>
+
+// CHECK-NEXT: %[[C2_i32:.*]] = llvm.mlir.constant(2 : i32) : !llvm.i32
+// CHECK-NEXT: %[[C0_:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+
+// CHECK-NEXT: %[[OFFSET_PTR:.*]] = llvm.getelementptr %[[ZERO_D_DESC]]{{\[}}
+// CHECK-SAME:   %[[C0_]], %[[C2_i32]]] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
+// CHECK-SAME:   i64)>>, !llvm.i64, !llvm.i32) -> !llvm.ptr<i64>
+
+// CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[INDEX_INC:.*]] = llvm.add %[[C1]], %[[C0]] : !llvm.i64
+
+// CHECK-NEXT: %[[SIZE_PTR:.*]] = llvm.getelementptr %[[OFFSET_PTR]]{{\[}}
+// CHECK-SAME:   %[[INDEX_INC]]] : (!llvm.ptr<i64>, !llvm.i64) -> !llvm.ptr<i64>
+
+// CHECK-NEXT: %[[SIZE:.*]] = llvm.load %[[SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK-NEXT: llvm.return %[[SIZE]] : !llvm.i64
+
+// CHECK32: %[[SIZE:.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
+// CHECK32-NEXT: llvm.return %[[SIZE]] : !llvm.i32
diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
--- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir
+++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
@@ -66,6 +66,7 @@
   call @return_var_memref_caller() : () -> ()
   call @return_two_var_memref_caller() : () -> ()
+  call @dim_op_of_unranked() : () -> ()
   return
 }
 
@@ -100,3 +101,25 @@
   %0 = memref_cast %arg0: memref<4x3xf32> to memref<*xf32>
   return %0 : memref<*xf32>
 }
+
+func @print_i64(index) -> ()
+func @print_newline() -> ()
+
+func @dim_op_of_unranked() {
+  %ranked = alloc() : memref<4x3xf32>
+  %unranked = memref_cast %ranked: memref<4x3xf32> to memref<*xf32>
+
+  %c0 = constant 0 : index
+  %dim_0 = dim %unranked, %c0 : memref<*xf32>
+  call @print_i64(%dim_0) : (index) -> ()
+  call @print_newline() : () -> ()
+  // CHECK: 4
+
+  %c1 = constant 1 : index
+  %dim_1 = dim %unranked, %c1 : memref<*xf32>
+  call @print_i64(%dim_1) : (index) -> ()
+  call @print_newline() : () -> ()
+  // CHECK: 3
+
+  return
+}
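
Note on the descriptor layout the new lowering relies on: a ranked memref lowers to a struct holding the allocated pointer, the aligned pointer, the offset, and then the sizes and strides arrays as contiguous index-typed fields, while an unranked memref lowers to a (rank, opaque pointer) pair. The offset is therefore struct field 2 (hence the `llvm.mlir.constant(2 : i32)` GEP index), and because the sizes array starts immediately behind the offset, the size of dimension `d` sits at `offsetPtr + 1 + d` regardless of the rank. Below is a minimal C++ sketch of that arithmetic, assuming the standard descriptor layout produced by the LLVM type converter; all names in it are illustrative, not part of this patch or of MLIR's API, and the real pointer arithmetic happens in the emitted LLVM IR, not in C++.

#include <cstdint>

// Rank-0 ("scalar") view of a ranked memref descriptor: enough to locate the
// offset field, since the sizes array starts immediately after it no matter
// what the actual rank is. This mirrors the bitcast to a memref<elem_type>
// descriptor pointer in extractSizeOfUnrankedMemRef.
template <typename T>
struct ScalarDescriptor {
  T *allocated;   // field 0: allocated pointer
  T *aligned;     // field 1: aligned pointer
  int64_t offset; // field 2: offset into the buffer; sizes[0] follows it
};

// Unranked memref descriptor: the rank plus a type-erased pointer to the
// ranked descriptor. The lowering extracts field 1 and bitcasts it.
struct UnrankedDescriptor {
  int64_t rank;
  void *rankedDescriptor;
};

// The GEP/load sequence emitted by the lowering, expressed as C++:
// &desc->offset + 1 + d points at sizes[d] of the underlying descriptor.
inline int64_t dimOfUnranked(const UnrankedDescriptor &u, int64_t d) {
  auto *scalar =
      static_cast<const ScalarDescriptor<int32_t> *>(u.rankedDescriptor);
  const int64_t *offsetPtr = &scalar->offset;
  return *(offsetPtr + 1 + d);
}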