diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -278,6 +278,7 @@
   /// Builds IR extracting the pos-th size from the descriptor.
   Value size(OpBuilder &builder, Location loc, unsigned pos);
+  Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);
 
   /// Builds IR inserting the pos-th size into the descriptor
   void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -568,6 +568,29 @@
       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
 }
 
+Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
+                             int64_t rank) {
+  auto indexTy = indexType.cast<LLVM::LLVMType>();
+  auto indexPtrTy = indexTy.getPointerTo();
+  auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank);
+  auto arrayPtrTy = arrayTy.getPointerTo();
+
+  // Copy size values to stack-allocated memory.
+  auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
+  auto one = createIndexAttrConstant(builder, loc, indexType, 1);
+  auto sizes = builder.create<LLVM::ExtractValueOp>(
+      loc, arrayTy, value,
+      builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
+  auto sizesPtr =
+      builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
+  builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
+
+  // Load and return the size value of interest.
+  auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
+                                               ValueRange({zero, pos}));
+  return builder.create<LLVM::LoadOp>(loc, resultPtr);
+}
+
 /// Builds IR inserting the pos-th size into the descriptor
 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
                                Value size) {
@@ -576,7 +599,6 @@
       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
 }
 
-/// Builds IR inserting the pos-th size into the descriptor
 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
                                        unsigned pos, uint64_t size) {
   setSize(builder, loc, pos,
@@ -598,7 +620,6 @@
       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
 }
 
-/// Builds IR inserting the pos-th stride into the descriptor
 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
                                          unsigned pos, uint64_t stride) {
   setStride(builder, loc, pos,
@@ -2117,25 +2138,29 @@
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto dimOp = cast<DimOp>(op);
+    auto loc = op->getLoc();
     DimOp::Adaptor transformed(operands);
-    MemRefType type = dimOp.memrefOrTensor().getType().cast<MemRefType>();
-    Optional<int64_t> index = dimOp.getConstantIndex();
-    if (!index.hasValue()) {
-      // TODO: Implement this lowering.
-      return failure();
+    // Take advantage of a constant index, if available.
+    MemRefType memRefType = dimOp.memrefOrTensor().getType().cast<MemRefType>();
+    if (Optional<int64_t> index = dimOp.getConstantIndex()) {
+      int64_t i = index.getValue();
+      if (memRefType.isDynamicDim(i)) {
+        // Extract dynamic size from the memref descriptor.
+        MemRefDescriptor descriptor(transformed.memrefOrTensor());
+        rewriter.replaceOp(op, {descriptor.size(rewriter, loc, i)});
+      } else {
+        // Use constant for static size.
+        int64_t dimSize = memRefType.getDimSize(i);
+        rewriter.replaceOp(op, createIndexConstant(rewriter, loc, dimSize));
+      }
+      return success();
     }
-    int64_t i = index.getValue();
-    // Extract dynamic size from the memref descriptor.
-    if (type.isDynamicDim(i))
-      rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
-                                  .size(rewriter, op->getLoc(), i)});
-    else
-      // Use constant for static size.
-      rewriter.replaceOp(
-          op, createIndexConstant(rewriter, op->getLoc(), type.getDimSize(i)));
-
+    Value index = dimOp.index();
+    int64_t rank = memRefType.getRank();
+    MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
+    rewriter.replaceOp(op, {memrefDescriptor.size(rewriter, loc, index, rank)});
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1308,7 +1308,7 @@
 }
 
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
-  auto index = operands[1].dyn_cast<IntegerAttr>();
+  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
 
   // All forms of folding require a known index.
   if (!index)
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -408,3 +408,26 @@
   %4 = dim %mixed, %c4 : memref<42x?x?x13x?xf32>
   return
 }
+
+// CHECK-LABEL: @memref_dim_with_dyn_index
+// CHECK-SAME: %[[ALLOC_PTR:.*]]: !llvm<"float*">, %[[ALIGN_PTR:.*]]: !llvm<"float*">, %[[OFFSET:.*]]: !llvm.i64, %[[SIZE0:.*]]: !llvm.i64, %[[SIZE1:.*]]: !llvm.i64, %[[STRIDE0:.*]]: !llvm.i64, %[[STRIDE1:.*]]: !llvm.i64, %[[IDX:.*]]: !llvm.i64) -> !llvm.i64
+func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index {
+  // CHECK-NEXT: %[[DESCR0:.*]] = llvm.mlir.undef : [[DESCR_TY:!llvm<"{ float\*, float\*, i64, \[2 x i64\], \[2 x i64\] }">]]
+  // CHECK-NEXT: %[[DESCR1:.*]] = llvm.insertvalue %[[ALLOC_PTR]], %[[DESCR0]][0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR2:.*]] = llvm.insertvalue %[[ALIGN_PTR]], %[[DESCR1]][1] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESCR2]][2] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR4:.*]] = llvm.insertvalue %[[SIZE0]], %[[DESCR3]][3, 0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESCR4]][4, 0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR6:.*]] = llvm.insertvalue %[[SIZE1]], %[[DESCR5]][3, 1] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR7:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESCR6]][4, 1] : [[DESCR_TY]]
+  // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK-DAG: %[[SIZES:.*]] = llvm.extractvalue %[[DESCR7]][3] : [[DESCR_TY]]
+  // CHECK-DAG: %[[SIZES_PTR:.*]] = llvm.alloca %[[C1]] x !llvm<"[2 x i64]"> : (!llvm.i64) -> !llvm<"[2 x i64]*">
+  // CHECK-DAG: llvm.store %[[SIZES]], %[[SIZES_PTR]] : !llvm<"[2 x i64]*">
+  // CHECK-DAG: %[[RESULT_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[C0]], %[[IDX]]] : (!llvm<"[2 x i64]*">, !llvm.i64, !llvm.i64) -> !llvm<"i64*">
+  // CHECK-DAG: %[[RESULT:.*]] = llvm.load %[[RESULT_PTR]] : !llvm<"i64*">
+  // CHECK-DAG: llvm.return %[[RESULT]] : !llvm.i64
+  %result = dim %arg, %idx : memref<3x?xf32>
+  return %result : index
+}
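
Note: for a non-constant index, the new lowering effectively produces the LLVM dialect sequence below. This is an illustrative sketch assembled from the CHECK lines in the test above; %descr and %idx stand for the converted memref descriptor and index operand, and the SSA names are made up. Because llvm.extractvalue only accepts constant positions, the descriptor's sizes array is first spilled to a stack slot, and the runtime index is then applied through llvm.getelementptr followed by llvm.load:

  // Constants for the alloca element count and the leading GEP index.
  %c0 = llvm.mlir.constant(0 : index) : !llvm.i64
  %c1 = llvm.mlir.constant(1 : index) : !llvm.i64
  // Copy the sizes array of the descriptor to stack-allocated memory.
  %sizes = llvm.extractvalue %descr[3] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  %sizes_ptr = llvm.alloca %c1 x !llvm<"[2 x i64]"> : (!llvm.i64) -> !llvm<"[2 x i64]*">
  llvm.store %sizes, %sizes_ptr : !llvm<"[2 x i64]*">
  // Address the requested size with the runtime index and load it.
  %result_ptr = llvm.getelementptr %sizes_ptr[%c0, %idx] : (!llvm<"[2 x i64]*">, !llvm.i64, !llvm.i64) -> !llvm<"i64*">
  %result = llvm.load %result_ptr : !llvm<"i64*">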