diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -278,6 +278,10 @@
 
   /// Builds IR extracting the pos-th size from the descriptor.
   Value size(OpBuilder &builder, Location loc, unsigned pos);
+  /// Builds IR extracting the size at the runtime position `pos` from the
+  /// descriptor. The sizes array is copied to a stack slot so that it can be
+  /// indexed dynamically.
+  Value size(OpBuilder &builder, Location loc, Value pos);
 
   /// Builds IR inserting the pos-th size into the descriptor
   void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
@@ -286,6 +290,10 @@
 
   /// Builds IR extracting the pos-th size from the descriptor.
   Value stride(OpBuilder &builder, Location loc, unsigned pos);
+  /// Builds IR extracting the stride at the runtime position `pos` from the
+  /// descriptor. The strides array is copied to a stack slot so that it can
+  /// be indexed dynamically.
+  Value stride(OpBuilder &builder, Location loc, Value pos);
 
   /// Builds IR inserting the pos-th stride into the descriptor
   void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -568,6 +568,41 @@
       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
 }
 
+/// Builds IR extracting the element at the runtime position `pos` from the
+/// array stored at position `arrayPos` of the memref `descriptor` struct.
+/// LLVM's getelementptr only indexes through pointers, so the array is first
+/// copied into a stack slot and the element is then loaded from that slot.
+static Value extractDynamicArrayElement(OpBuilder &builder, Location loc,
+                                        Value descriptor, Type indexType,
+                                        unsigned arrayPos, Value pos) {
+  auto indexTy = indexType.cast<LLVM::LLVMType>();
+  auto descriptorTy = descriptor.getType().cast<LLVM::LLVMType>();
+  auto arrayTy = descriptorTy.getStructElementType(arrayPos);
+
+  // Copy the array to stack-allocated memory.
+  // NOTE(review): the alloca is emitted at the current insertion point, so a
+  // dim inside a loop allocates a fresh stack slot per iteration.
+  Value zero = createIndexAttrConstant(builder, loc, indexType, 0);
+  Value one = createIndexAttrConstant(builder, loc, indexType, 1);
+  Value array = builder.create<LLVM::ExtractValueOp>(
+      loc, arrayTy, descriptor, builder.getI64ArrayAttr({arrayPos}));
+  Value arrayPtr = builder.create<LLVM::AllocaOp>(
+      loc, arrayTy.getPointerTo(), one, /*alignment=*/0);
+  builder.create<LLVM::StoreOp>(loc, array, arrayPtr);
+
+  // Load the element of interest through a pointer into the copy.
+  Value elementPtr = builder.create<LLVM::GEPOp>(
+      loc, indexTy.getPointerTo(), arrayPtr, ValueRange({zero, pos}));
+  return builder.create<LLVM::LoadOp>(loc, elementPtr);
+}
+
+/// Builds IR extracting the size at the runtime position `pos` from the
+/// descriptor.
+Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos) {
+  return extractDynamicArrayElement(builder, loc, value, indexType,
+                                    kSizePosInMemRefDescriptor, pos);
+}
+
 /// Builds IR inserting the pos-th size into the descriptor
 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
                                Value size) {
@@ -576,7 +611,6 @@
       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
 }
 
-/// Builds IR inserting the pos-th size into the descriptor
 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
                                        unsigned pos, uint64_t size) {
   setSize(builder, loc, pos,
@@ -590,6 +624,13 @@
       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
 }
 
+/// Builds IR extracting the stride at the runtime position `pos` from the
+/// descriptor.
+Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, Value pos) {
+  return extractDynamicArrayElement(builder, loc, value, indexType,
+                                    kStridePosInMemRefDescriptor, pos);
+}
+
 /// Builds IR inserting the pos-th stride into the descriptor
 void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
                                  Value stride) {
@@ -598,7 +639,6 @@
       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
 }
 
-/// Builds IR inserting the pos-th stride into the descriptor
 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
                                          unsigned pos, uint64_t stride) {
   setStride(builder, loc, pos,
@@ -2113,25 +2153,29 @@
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto dimOp = cast<DimOp>(op);
+    auto loc = op->getLoc();
     OperandAdaptor<DimOp> transformed(operands);
-    MemRefType type = dimOp.memrefOrTensor().getType().cast<MemRefType>();
 
-    Optional<int64_t> index = dimOp.getConstantIndex();
-    if (!index.hasValue()) {
-      // TODO(frgossen): Implement this lowering.
-      return failure();
+    // Take advantage if index is constant.
+    if (Optional<int64_t> index = dimOp.getConstantIndex()) {
+      MemRefType type = dimOp.memrefOrTensor().getType().cast<MemRefType>();
+      int64_t i = index.getValue();
+      if (type.isDynamicDim(i)) {
+        // Extract dynamic size from the memref descriptor.
+        MemRefDescriptor descriptor(transformed.memrefOrTensor());
+        rewriter.replaceOp(op, {descriptor.size(rewriter, loc, i)});
+      } else {
+        // Use constant for static size.
+        int64_t dimSize = type.getDimSize(i);
+        rewriter.replaceOp(op, createIndexConstant(rewriter, loc, dimSize));
+      }
+      return success();
     }
 
-    int64_t i = index.getValue();
-    // Extract dynamic size from the memref descriptor.
-    if (type.isDynamicDim(i))
-      rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
-                                  .size(rewriter, op->getLoc(), i)});
-    else
-      // Use constant for static size.
-      rewriter.replaceOp(
-          op, createIndexConstant(rewriter, op->getLoc(), type.getDimSize(i)));
-
+    // Use the converted (LLVM-typed) index operand, not the original one.
+    Value index = transformed.index();
+    MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
+    rewriter.replaceOp(op, {memrefDescriptor.size(rewriter, loc, index)});
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1305,7 +1305,7 @@
 }
 
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
-  auto index = operands[1].dyn_cast<IntegerAttr>();
+  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
 
   // All forms of folding require a known index.
   if (!index)
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -408,3 +408,29 @@
   %4 = dim %mixed, %c4 : memref<42x?x?x13x?xf32>
   return
 }
+
+// CHECK-LABEL: @memref_dim_with_dyn_index
+// CHECK-SAME: %[[ALLOC_PTR:.*]]: !llvm<"float*">, %[[ALIGN_PTR:.*]]: !llvm<"float*">, %[[OFFSET:.*]]: !llvm.i64, %[[SIZE0:.*]]: !llvm.i64, %[[SIZE1:.*]]: !llvm.i64, %[[SIZE2:.*]]: !llvm.i64, %[[STRIDE0:.*]]: !llvm.i64, %[[STRIDE1:.*]]: !llvm.i64, %[[STRIDE2:.*]]: !llvm.i64, %[[IDX:.*]]: !llvm.i64) -> !llvm.i64
+func @memref_dim_with_dyn_index(%arg : memref<3x4x5xf32>, %idx : index)
+    -> index {
+  // CHECK-NEXT: %[[DESCR0:.*]] = llvm.mlir.undef : [[DESCR_TY:!llvm<"{ float\*, float\*, i64, \[3 x i64\], \[3 x i64\] }">]]
+  // CHECK-NEXT: %[[DESCR1:.*]] = llvm.insertvalue %[[ALLOC_PTR]], %[[DESCR0]][0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR2:.*]] = llvm.insertvalue %[[ALIGN_PTR]], %[[DESCR1]][1] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESCR2]][2] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR4:.*]] = llvm.insertvalue %[[SIZE0]], %[[DESCR3]][3, 0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESCR4]][4, 0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR6:.*]] = llvm.insertvalue %[[SIZE1]], %[[DESCR5]][3, 1] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR7:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESCR6]][4, 1] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR8:.*]] = llvm.insertvalue %[[SIZE2]], %[[DESCR7]][3, 2] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR9:.*]] = llvm.insertvalue %[[STRIDE2]], %[[DESCR8]][4, 2] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK-NEXT: %[[SIZES:.*]] = llvm.extractvalue %[[DESCR9]][3] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[SIZES_PTR:.*]] = llvm.alloca %[[C1]] x !llvm<"[3 x i64]"> : (!llvm.i64) -> !llvm<"[3 x i64]*">
+  // CHECK-NEXT: llvm.store %[[SIZES]], %[[SIZES_PTR]] : !llvm<"[3 x i64]*">
+  // CHECK-NEXT: %[[RESULT_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[C0]], %[[IDX]]] : (!llvm<"[3 x i64]*">, !llvm.i64, !llvm.i64) -> !llvm<"i64*">
+  // CHECK-NEXT: %[[RESULT:.*]] = llvm.load %[[RESULT_PTR]] : !llvm<"i64*">
+  // CHECK-NEXT: llvm.return %[[RESULT]] : !llvm.i64
+  %result = dim %arg, %idx : memref<3x4x5xf32>
+  return %result : index
+}