diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -517,18 +517,18 @@ // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. - Value getStridedElementPtr(Location loc, Type elementTypePtr, - Value descriptor, ValueRange indices, - ArrayRef strides, int64_t offset, + Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, + ValueRange indices, ConversionPatternRewriter &rewriter) const; - /// Returns if the givem memref type is supported. - bool isSupportedMemRefType(MemRefType type) const; - + // Forwards to getStridedElementPtr. TODO: remove. Value getDataPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const; + /// Returns true if the given memref type is supported. + bool isSupportedMemRefType(MemRefType type) const; + /// Returns the type of a pointer to an element of the memref. 
Type getElementPtrType(MemRefType type) const; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1059,39 +1059,45 @@ } Value ConvertToLLVMPattern::getStridedElementPtr( - Location loc, Type elementTypePtr, Value descriptor, ValueRange indices, - ArrayRef strides, int64_t offset, + Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { - MemRefDescriptor memRefDescriptor(descriptor); + int64_t offset; + SmallVector strides; + auto successStrides = getStridesAndOffset(type, strides, offset); + assert(succeeded(successStrides) && "unexpected non-strided memref"); + (void)successStrides; + + MemRefDescriptor memRefDescriptor(memRefDesc); Value base = memRefDescriptor.alignedPtr(rewriter, loc); - Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset() - ? memRefDescriptor.offset(rewriter, loc) - : createIndexConstant(rewriter, loc, offset); + + Value index; + if (offset != 0) // Skip if offset is zero. + index = offset == MemRefType::getDynamicStrideOrOffset() + ? memRefDescriptor.offset(rewriter, loc) + : createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { - Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset() - ? memRefDescriptor.stride(rewriter, loc, i) - : createIndexConstant(rewriter, loc, strides[i]); - Value additionalOffset = - rewriter.create(loc, indices[i], stride); - offsetValue = - rewriter.create(loc, offsetValue, additionalOffset); + Value increment = indices[i]; + if (strides[i] != 1) { // Skip if stride is 1. + Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset() + ? 
memRefDescriptor.stride(rewriter, loc, i) + : createIndexConstant(rewriter, loc, strides[i]); + increment = rewriter.create(loc, increment, stride); + } + index = + index ? rewriter.create(loc, index, increment) : increment; } - return rewriter.create(loc, elementTypePtr, base, offsetValue); + + LLVM::LLVMType elementPtrType = memRefDescriptor.getElementPtrType(); + return index ? rewriter.create(loc, elementPtrType, base, index) + : base; } Value ConvertToLLVMPattern::getDataPtr( Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { - LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementPtrType(); - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); - assert(succeeded(successStrides) && "unexpected non-strided memref"); - (void)successStrides; - return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides, - offset, rewriter); + return getStridedElementPtr(loc, type, memRefDesc, indices, rewriter); } // Check if the MemRefType `type` is supported by the lowering. 
We currently @@ -3092,8 +3098,9 @@ LoadOp::Adaptor transformed(operands); auto type = loadOp.getMemRefType(); - Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter); + Value dataPtr = + getStridedElementPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter); rewriter.replaceOpWithNewOp(op, dataPtr); return success(); } @@ -3110,8 +3117,9 @@ auto type = cast(op).getMemRefType(); StoreOp::Adaptor transformed(operands); - Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter); + Value dataPtr = + getStridedElementPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); return success(); @@ -3130,8 +3138,9 @@ PrefetchOp::Adaptor transformed(operands); auto type = prefetchOp.getMemRefType(); - Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter); + Value dataPtr = + getStridedElementPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter); // Replace with llvm.prefetch. auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); @@ -3836,8 +3845,9 @@ AtomicRMWOp::Adaptor adaptor(operands); auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); - auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + auto dataPtr = + getStridedElementPtr(op->getLoc(), memRefType, adaptor.memref(), + adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp( op, resultType, *maybeKind, dataPtr, adaptor.value(), LLVM::AtomicOrdering::acq_rel); @@ -3902,8 +3912,8 @@ // Compute the loaded value and branch to the loop block. 
rewriter.setInsertionPointToEnd(initBlock); auto memRefType = atomicOp.memref().getType().cast(); - auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), + adaptor.indices(), rewriter); Value init = rewriter.create(loc, dataPtr); rewriter.create(loc, init, loopBlock); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1277,8 +1277,8 @@ // addrspacecast shall be used when source/dst memrefs are not on // address space 0. // TODO: support alignment when possible. - Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), + adaptor.indices(), rewriter); auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast(); Value vectorDataPtr; diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -102,8 +102,8 @@ // Note that the dataPtr starts at the offset address specified by // indices, so no need to calculate offset size in bytes again in // the MUBUF instruction. - Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), + adaptor.indices(), rewriter); // 1. Create and fill a <4 x i32> dwordConfig with: // 1st two elements holding the address of dataPtr. 
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir @@ -193,13 +193,9 @@ // CHECK: %[[J:.*]]: !llvm.i64) func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr %0 = load %mixed[%i, %j] : memref<42x?xf32> @@ -218,13 +214,9 @@ // CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 func @dynamic_load(%dynamic : memref, %i : index, %j : index) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 
: index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr %0 = load %dynamic[%i, %j] : memref @@ -243,13 +235,9 @@ // CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 func @prefetch(%A : memref, %i : index, %j : index) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 // CHECK-NEXT: [[C3:%.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32 @@ -281,13 +269,9 @@ // CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 func @dynamic_store(%dynamic : memref, %i : index, %j : index, %val : f32) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // 
CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr store %val, %dynamic[%i, %j] : memref @@ -306,13 +290,9 @@ // CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr store %val, %mixed[%i, %j] : memref<42x?xf32> diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir +++ 
b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -249,14 +249,10 @@ // BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm.ptr) -> !llvm.float func @zero_d_load(%arg0: memref) -> f32 { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> -// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr -// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm.ptr +// CHECK-NEXT: %{{.*}} = llvm.load %[[ptr]] : !llvm.ptr // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> -// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr -// BAREPTR-NEXT: llvm.load %[[addr:.*]] : !llvm.ptr +// BAREPTR-NEXT: llvm.load %[[ptr]] : !llvm.ptr %0 = load %arg0[] : memref return %0 : f32 } @@ -272,24 +268,16 @@ // BAREPTR-SAME: (%[[A:.*]]: !llvm.ptr, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) { func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.load 
%[[addr]] : !llvm.ptr // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: llvm.load %[[addr]] : !llvm.ptr %0 = load %static[%i, %j] : memref<10x42xf32> @@ -303,14 +291,10 @@ // BAREPTR-SAME: (%[[A:.*]]: !llvm.ptr, %[[val:.*]]: !llvm.float) func @zero_d_store(%arg0: memref, %arg1: f32) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr -// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr +// CHECK-NEXT: llvm.store %{{.*}}, %[[ptr]] : !llvm.ptr // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> -// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr -// BAREPTR-NEXT: llvm.store %[[val]], %[[addr]] : !llvm.ptr +// BAREPTR-NEXT: llvm.store %[[val]], %[[ptr]] : !llvm.ptr store %arg1, %arg0[] : memref return } @@ -333,24 +317,16 @@ // BAREPTR-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 func @static_store(%static : memref<10x42xf32>, %i : 
index, %j : index, %val : f32) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr store %val, %static[%i, %j] : memref<10x42xf32>