diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -278,6 +278,9 @@
 
   /// Builds IR extracting the pos-th size from the descriptor.
   Value size(OpBuilder &builder, Location loc, unsigned pos);
+  /// Builds IR extracting the size at a runtime position `pos` from the
+  /// descriptor; `rank` is the statically known rank of the memref.
+  Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);
 
   /// Builds IR inserting the pos-th size into the descriptor
   void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -568,6 +568,29 @@
       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
 }
 
+Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
+                             int64_t rank) {
+  // Allocate memory for all size values on the stack.
+  auto indexPtrTy = indexType.cast<LLVM::LLVMType>().getPointerTo();
+  Value rankVal = createIndexAttrConstant(builder, loc, indexType, rank);
+  Value sizesPtr = builder.create<LLVM::AllocaOp>(loc, indexPtrTy, rankVal,
+                                                  /*alignment=*/0);
+
+  // Copy the size values to the stack-allocated memory.
+  for (int64_t i = 0; i < rank; i++) {
+    Value iVal = createIndexAttrConstant(builder, loc, indexType, i);
+    Value ithSizePtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
+                                                   ValueRange({iVal}));
+    Value ithSizeVal = size(builder, loc, i);
+    builder.create<LLVM::StoreOp>(loc, ithSizeVal, ithSizePtr);
+  }
+
+  // Load the size value of interest.
+  Value resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
+                                                ValueRange({pos}));
+  return builder.create<LLVM::LoadOp>(loc, resultPtr);
+}
+
 /// Builds IR inserting the pos-th size into the descriptor
 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
                                Value size) {
@@ -576,7 +600,6 @@
       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
 }
 
-/// Builds IR inserting the pos-th size into the descriptor
 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
                                        unsigned pos, uint64_t size) {
   setSize(builder, loc, pos,
@@ -598,7 +621,6 @@
       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
 }
 
-/// Builds IR inserting the pos-th stride into the descriptor
 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
                                          unsigned pos, uint64_t stride) {
   setStride(builder, loc, pos,
@@ -2113,25 +2135,29 @@
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto dimOp = cast<DimOp>(op);
+    auto loc = op->getLoc();
    OperandAdaptor<DimOp> transformed(operands);
-    MemRefType type = dimOp.memrefOrTensor().getType().cast<MemRefType>();
-    Optional<int64_t> index = dimOp.getConstantIndex();
-    if (!index.hasValue()) {
-      // TODO(frgossen): Implement this lowering.
-      return failure();
+    // Take advantage if the index is constant.
+    MemRefType memRefType =
+        dimOp.memrefOrTensor().getType().cast<MemRefType>();
+    if (Optional<int64_t> index = dimOp.getConstantIndex()) {
+      int64_t i = index.getValue();
+      if (memRefType.isDynamicDim(i)) {
+        // Extract the dynamic size from the memref descriptor.
+        MemRefDescriptor descriptor(transformed.memrefOrTensor());
+        rewriter.replaceOp(op, {descriptor.size(rewriter, loc, i)});
+      } else {
+        // Use a constant for the static size.
+        int64_t dimSize = memRefType.getDimSize(i);
+        rewriter.replaceOp(op, createIndexConstant(rewriter, loc, dimSize));
+      }
+      return success();
     }
-    int64_t i = index.getValue();
-    // Extract dynamic size from the memref descriptor.
-    if (type.isDynamicDim(i))
-      rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
-                                  .size(rewriter, op->getLoc(), i)});
-    else
-      // Use constant for static size.
-      rewriter.replaceOp(
-          op, createIndexConstant(rewriter, op->getLoc(), type.getDimSize(i)));
-
+
+    // Otherwise, fall back to the runtime-index lowering.
+    Value index = dimOp.index();
+    int64_t rank = memRefType.getRank();
+    MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
+    rewriter.replaceOp(op, {memrefDescriptor.size(rewriter, loc, index, rank)});
     return success();
   }
 };
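For intuition, the new MemRefDescriptor::size overload above emits the LLVM-dialect analogue of the following standalone C++ sketch (a minimal sketch; the function and variable names are illustrative, not part of the patch). The idea is to spill all `rank` sizes into a stack buffer so that a single indexed load can select one at a position only known at runtime:

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the lowering strategy: `sizes` stands in for the descriptor's
// size fields, and `pos` is a value only known at runtime.
int64_t sizeAtRuntimeIndex(const std::vector<int64_t> &sizes, int64_t pos) {
  int64_t rank = static_cast<int64_t>(sizes.size());
  assert(pos >= 0 && pos < rank && "dim index out of bounds");
  // llvm.alloca: reserve `rank` i64 slots on the stack.
  std::vector<int64_t> buffer(rank);
  // llvm.getelementptr + llvm.store: copy each size into the buffer.
  for (int64_t i = 0; i < rank; ++i)
    buffer[i] = sizes[i];
  // llvm.getelementptr + llvm.load: read the size at the runtime index.
  return buffer[pos];
}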
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1305,7 +1305,7 @@
 }
 
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
-  auto index = operands[1].dyn_cast<IntegerAttr>();
+  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
 
   // All forms of folding require a known index.
   if (!index)
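The one-line fold fix above is needed because fold hooks receive a null Attribute for every operand that is not a compile-time constant. Now that the dim index may be a runtime value, operands[1] can be null, and a plain dyn_cast would assert. A minimal sketch of the distinction (the helper name is illustrative, not part of the patch):

#include "mlir/IR/Attributes.h"

// Returns true only if the dim index operand folded to a constant integer.
// `index` is null whenever the operand is a runtime value, which is exactly
// the case this patch starts to support.
static bool hasConstantIndex(mlir::Attribute index) {
  // index.dyn_cast<mlir::IntegerAttr>() asserts on a null Attribute;
  // dyn_cast_or_null returns a null IntegerAttr instead.
  return static_cast<bool>(index.dyn_cast_or_null<mlir::IntegerAttr>());
}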
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -408,3 +408,32 @@
   %4 = dim %mixed, %c4 : memref<42x?x?x13x?xf32>
   return
 }
+
+// CHECK-LABEL: @memref_dim_with_dyn_index
+// CHECK-SAME: %[[ALLOC_PTR:.*]]: !llvm<"float*">, %[[ALIGN_PTR:.*]]: !llvm<"float*">, %[[OFFSET:.*]]: !llvm.i64, %[[SIZE0:.*]]: !llvm.i64, %[[SIZE1:.*]]: !llvm.i64, %[[STRIDE0:.*]]: !llvm.i64, %[[STRIDE1:.*]]: !llvm.i64, %[[IDX:.*]]: !llvm.i64) -> !llvm.i64
+func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index)
+    -> index {
+  // CHECK-NEXT: %[[DESCR0:.*]] = llvm.mlir.undef : [[DESCR_TY:!llvm<"{ float\*, float\*, i64, \[2 x i64\], \[2 x i64\] }">]]
+  // CHECK-NEXT: %[[DESCR1:.*]] = llvm.insertvalue %[[ALLOC_PTR]], %[[DESCR0]][0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR2:.*]] = llvm.insertvalue %[[ALIGN_PTR]], %[[DESCR1]][1] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESCR2]][2] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR4:.*]] = llvm.insertvalue %[[SIZE0]], %[[DESCR3]][3, 0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESCR4]][4, 0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR6:.*]] = llvm.insertvalue %[[SIZE1]], %[[DESCR5]][3, 1] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR7:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESCR6]][4, 1] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[RANK:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+  // CHECK-NEXT: %[[SIZES_PTR:.*]] = llvm.alloca %[[RANK]] x !llvm.i64 : (!llvm.i64) -> !llvm<"i64*">
+  // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[SIZE0_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[C0]]] : (!llvm<"i64*">, !llvm.i64) -> !llvm<"i64*">
+  // CHECK-NEXT: %[[SIZE0:.*]] = llvm.extractvalue %[[DESCR7]][3, 0] : [[DESCR_TY]]
+  // CHECK-NEXT: llvm.store %[[SIZE0]], %[[SIZE0_PTR]] : !llvm<"i64*">
+  // CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK-NEXT: %[[SIZE1_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[C1]]] : (!llvm<"i64*">, !llvm.i64) -> !llvm<"i64*">
+  // CHECK-NEXT: %[[SIZE1:.*]] = llvm.extractvalue %[[DESCR7]][3, 1] : [[DESCR_TY]]
+  // CHECK-NEXT: llvm.store %[[SIZE1]], %[[SIZE1_PTR]] : !llvm<"i64*">
+  // CHECK-NEXT: %[[RESULT_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[IDX]]] : (!llvm<"i64*">, !llvm.i64) -> !llvm<"i64*">
+  // CHECK-NEXT: %[[RESULT:.*]] = llvm.load %[[RESULT_PTR]] : !llvm<"i64*">
+  // CHECK-NEXT: llvm.return %[[RESULT]] : !llvm.i64
+  %result = dim %arg, %idx : memref<3x?xf32>
+  return %result : index
+}
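For readers tracing the CHECK lines: the test first rebuilds the memref descriptor from the function arguments and then exercises the new runtime-index path. For the rank-2 memref<3x?xf32> used here, the descriptor type !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> corresponds to the following C++ layout (the struct name is illustrative; the field order follows MLIR's documented memref descriptor):

#include <cstdint>

// C++ view of the rank-2 descriptor checked in the test above.
struct MemRefDescriptorRank2 {
  float *allocatedPtr;  // field 0: pointer returned by the allocator
  float *alignedPtr;    // field 1: aligned pointer used for data accesses
  int64_t offset;       // field 2: offset into the aligned buffer
  int64_t sizes[2];     // field 3: size of each dimension
  int64_t strides[2];   // field 4: stride of each dimension
};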