diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -2304,7 +2304,8 @@ // Currently, only rank > 0 and full or no operands are supported. Fail to // convert otherwise. unsigned rank = sourceMemRefType.getRank(); - if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) || + if (viewMemRefType.getRank() == 0 || + (!dynamicOffsets.empty() && rank != dynamicOffsets.size()) || (!dynamicSizes.empty() && rank != dynamicSizes.size()) || (!dynamicStrides.empty() && rank != dynamicStrides.size())) return matchFailure(); @@ -2315,6 +2316,11 @@ if (failed(successStrides)) return matchFailure(); + // Fail to convert if neither a dynamic nor static offset is available. + if (dynamicOffsets.empty() && + offset == MemRefType::getDynamicStrideOrOffset()) + return matchFailure(); + // Create the descriptor. MemRefDescriptor sourceMemRef(operands.front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); @@ -2348,14 +2354,18 @@ } // Offset. - Value baseOffset = sourceMemRef.offset(rewriter, loc); - for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { - Value min = dynamicOffsets[i]; - baseOffset = rewriter.create( - loc, baseOffset, - rewriter.create(loc, min, strideValues[i])); + if (dynamicOffsets.empty()) { + targetMemRef.setConstantOffset(rewriter, loc, offset); + } else { + Value baseOffset = sourceMemRef.offset(rewriter, loc); + for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { + Value min = dynamicOffsets[i]; + baseOffset = rewriter.create( + loc, baseOffset, + rewriter.create(loc, min, strideValues[i])); + } + targetMemRef.setOffset(rewriter, loc, baseOffset); } - targetMemRef.setOffset(rewriter, loc, baseOffset); // Update sizes and strides. 
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -815,6 +815,31 @@ return } +// CHECK-LABEL: func @subview_const_stride_and_offset( +func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>) { + // Capture the last "insertvalue" that builds the source descriptor from the function arguments; since the subview below has no offset/size/stride operands, the new descriptor's offset (8), sizes (62, 3) and strides (4, 1) are expected as constants taken from the static result type. + // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] + + // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[CST62:.*]] = llvm.mlir.constant(62 : i64) + // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i64) + // CHECK: %[[CST8:.*]] = llvm.mlir.constant(8 : index) + // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[CST8]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST3]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i64) + // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], 
[2 x i64] }"> + // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) + // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + %1 = subview %0[][][] : + memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>> + return +} + // ----- module {