diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1518,6 +1518,7 @@ ConversionPatternRewriter &rewriter, Location loc, Operation *op, TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc, ArrayRef reassociation) { + auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); // See comments for computeCollapsedLayoutMap for details on how the strides // are calculated. auto srcShape = srcType.getShape(); @@ -1579,8 +1580,8 @@ rewriter.create(loc, srcStride, continueBlock); break; } - Value one = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI32IntegerAttr(1)); + Value one = rewriter.create(loc, llvmIndexType, + rewriter.getIndexAttr(1)); Value predNeOne = rewriter.create( loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex), one); diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -776,7 +776,7 @@ // CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(1 : i32) : i64 +// CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i64 // CHECK: llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i64), ^bb1 @@ -785,6 +785,10 @@ // CHECK: llvm.br ^bb2(%{{.*}} : i64) // CHECK: ^bb2(%[[STRIDE:.*]]: i64): // CHECK: llvm.insertvalue %[[STRIDE]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout( +// CHECK32: llvm.mlir.constant(1 : index) : i32 +// CHECK32: llvm.mlir.constant(4 : index) : i32 +// CHECK32: llvm.mlir.constant(1 : index) : i32 // ----- @@ -1149,7 +1153,7 @@ // CHECK-LABEL: func @extract_aligned_pointer_as_index func.func @extract_aligned_pointer_as_index(%m: memref) -> index { %0 = memref.extract_aligned_pointer_as_index %m: memref -> index - // CHECK: %[[E:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[E:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[E]] : !llvm.ptr to i64 // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index