diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -349,6 +349,9 @@ let extraClassDeclaration = [{ constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min(); }]; + let verifier = [{ + return ::verify(*this); + }]; } def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" @@ -359,6 +360,58 @@ // Code for LLVM::GEPOp. //===----------------------------------------------------------------------===// +/// Populates `indices` with positions of GEP indices that would correspond to +/// LLVMStructTypes potentially nested in the given type. The type currently +/// visited gets `currentIndex` and LLVM container types are visited +/// recursively. The recursion is bounded and takes care of recursive types by +/// means of the `visited` set. 
+static void recordStructIndices(Type type, unsigned currentIndex, + SmallVectorImpl<unsigned> &indices, + SmallVectorImpl<unsigned> *structSizes, + SmallPtrSet<Type, 4> &visited) { + if (visited.contains(type)) + return; + + visited.insert(type); + + llvm::TypeSwitch<Type>(type) + .Case<LLVMStructType>([&](LLVMStructType structType) { + indices.push_back(currentIndex); + if (structSizes) + structSizes->push_back(structType.getBody().size()); + for (Type elementType : structType.getBody()) + recordStructIndices(elementType, currentIndex + 1, indices, + structSizes, visited); + }) + .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, + LLVMArrayType>([&](auto containerType) { + recordStructIndices(containerType.getElementType(), currentIndex + 1, + indices, structSizes, visited); + }); +} + +/// Populates `indices` with positions of GEP indices that correspond to +/// LLVMStructTypes potentially nested in the given `baseGEPType`, which must +/// be either an LLVMPointer type or a vector thereof. If `structSizes` is +/// provided, it is populated with sizes of the indexed structs for bounds +/// verification purposes. 
+static void +findKnownStructIndices(Type baseGEPType, SmallVectorImpl<unsigned> &indices, + SmallVectorImpl<unsigned> *structSizes = nullptr) { + Type type = baseGEPType; + if (auto vectorType = type.dyn_cast<VectorType>()) + type = vectorType.getElementType(); + if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>()) + type = scalableVectorType.getElementType(); + if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>()) + type = fixedVectorType.getElementType(); + + Type pointeeType = type.cast<LLVMPointerType>().getElementType(); + SmallPtrSet<Type, 4> visited; + recordStructIndices(pointeeType, /*currentIndex=*/1, indices, structSizes, + visited); +} + void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, Value basePtr, ValueRange operands, ArrayRef<NamedAttribute> attributes) { @@ -371,11 +424,58 @@ Value basePtr, ValueRange indices, ArrayRef<int32_t> structIndices, ArrayRef<NamedAttribute> attributes) { + SmallVector<Value> remainingIndices; + SmallVector<int32_t> updatedStructIndices(structIndices.begin(), + structIndices.end()); + SmallVector<unsigned> structRelatedPositions; + findKnownStructIndices(basePtr.getType(), structRelatedPositions); + + SmallVector<unsigned> operandsToErase; + for (unsigned pos : structRelatedPositions) { + // GEP may not be indexing as deep as some structs are located. + if (pos >= structIndices.size()) + continue; + + // If the index is already static, it's fine. + if (structIndices[pos] != kDynamicIndex) + continue; + + // Find the corresponding operand. + unsigned operandPos = + std::count(structIndices.begin(), std::next(structIndices.begin(), pos), + kDynamicIndex); + + // Extract the constant value from the operand and put it into the attribute + // instead. 
+ APInt staticIndexValue; + bool matched = + matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue)); + (void)matched; + assert(matched && "index into a struct must be a constant"); + assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) && + "struct index underflows 32-bit integer"); + assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) && + "struct index overflows 32-bit integer"); + auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue()); + updatedStructIndices[pos] = staticIndex; + operandsToErase.push_back(operandPos); + } + + for (unsigned i = 0, e = indices.size(); i < e; ++i) { + if (llvm::find(operandsToErase, i) == operandsToErase.end()) + remainingIndices.push_back(indices[i]); + } + + assert(remainingIndices.size() == + llvm::count(updatedStructIndices, kDynamicIndex) && + "expected as many index operands as dynamic index attr elements"); + result.addTypes(resultType); result.addAttributes(attributes); - result.addAttribute("structIndices", builder.getI32TensorAttr(structIndices)); + result.addAttribute("structIndices", + builder.getI32TensorAttr(updatedStructIndices)); result.addOperands(basePtr); - result.addOperands(indices); + result.addOperands(remainingIndices); } static ParseResult @@ -416,6 +516,27 @@ }); } +LogicalResult verify(LLVM::GEPOp gepOp) { + SmallVector<unsigned> indices; + SmallVector<unsigned> structSizes; + findKnownStructIndices(gepOp.getBase().getType(), indices, &structSizes); + for (unsigned i = 0, e = indices.size(); i < e; ++i) { + unsigned index = indices[i]; + // GEP may not be indexing as deep as some structs nested in the type. 
+ if (index >= gepOp.getStructIndices().getNumElements()) + continue; + + int32_t staticIndex = gepOp.getStructIndices().getValues<int32_t>()[index]; + if (staticIndex == LLVM::GEPOp::kDynamicIndex) + return gepOp.emitOpError() << "expected index " << index + << " indexing a struct to be constant"; + if (staticIndex < 0 || staticIndex >= structSizes[i]) + return gepOp.emitOpError() + << "index " << index << " indexing a struct is out of bounds"; + } + return success(); +} + //===----------------------------------------------------------------------===// // Builder, printer and parser for for LLVM::LoadOp. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir @@ -501,8 +501,7 @@ // CHECK: [[STRUCT_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] // CHECK-SAME: !llvm.ptr to !llvm.ptr, ptr, i64, i64)>> // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: [[C3_I32:%.*]] = llvm.mlir.constant(3 : i32) : i32 -// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], [[C3_I32]]] +// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], 3] // CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]] // CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]] // CHECK: [[C1_:%.*]] = llvm.mlir.constant(1 : index) : i64 diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -547,12 +547,11 @@ // CHECK: %[[ZERO_D_DESC:.*]] = llvm.bitcast %[[RANKED_DESC]] // CHECK-SAME: : !llvm.ptr to !llvm.ptr, ptr, 
i64)>> -// CHECK: %[[C2_i32:.*]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: %[[C0_:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK: %[[OFFSET_PTR:.*]] = llvm.getelementptr %[[ZERO_D_DESC]]{{\[}} -// CHECK-SAME: %[[C0_]], %[[C2_i32]]] : (!llvm.ptr, ptr, -// CHECK-SAME: i64)>>, i64, i32) -> !llvm.ptr +// CHECK-SAME: %[[C0_]], 2] : (!llvm.ptr, ptr, +// CHECK-SAME: i64)>>, i64) -> !llvm.ptr // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: %[[INDEX_INC:.*]] = llvm.add %[[C1]], %{{.*}} : i64 diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir @@ -10,7 +10,7 @@ %0 = spv.Constant 1: i32 %1 = spv.Variable : !spv.ptr)>, Function> // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr)>>, i32, i32, i32) -> !llvm.ptr + // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], 1, %[[ONE]]] : (!llvm.ptr)>>, i32, i32) -> !llvm.ptr %2 = spv.AccessChain %1[%0, %0] : !spv.ptr)>, Function>, i32, i32 spv.Return } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1234,3 +1234,19 @@ nvvm.cp.async.shared.global %arg0, %arg1, 32 return } + +// ----- + +func @gep_struct_variable(%arg0: !llvm.ptr>, %arg1: i32, %arg2: i32) { + // expected-error @below {{op expected index 1 indexing a struct to be constant}} + llvm.getelementptr %arg0[%arg1, %arg1] : (!llvm.ptr>, i32, i32) -> !llvm.ptr + return +} + +// ----- + +func @gep_out_of_bounds(%ptr: !llvm.ptr)>>, %idx: i64) { + // expected-error @below {{index 2 indexing a struct is out of bounds}} + llvm.getelementptr %ptr[%idx, 1, 3] : (!llvm.ptr)>>, i64) -> !llvm.ptr + return +} diff --git 
a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1416,7 +1416,7 @@ %z32 = llvm.mlir.constant(0 : i32) : i32 %0 = llvm.mlir.undef : !llvm.struct<(i32, !llvm.ptr)> %1 = llvm.mlir.addressof @take_self_address : !llvm.ptr)>> - %2 = llvm.getelementptr %1[%z32, %z32] : (!llvm.ptr)>>, i32, i32) -> !llvm.ptr + %2 = llvm.getelementptr %1[%z32, 0] : (!llvm.ptr)>>, i32) -> !llvm.ptr %3 = llvm.insertvalue %z32, %0[0 : i32] : !llvm.struct<(i32, !llvm.ptr)> %4 = llvm.insertvalue %2, %3[1 : i32] : !llvm.struct<(i32, !llvm.ptr)> llvm.return %4 : !llvm.struct<(i32, !llvm.ptr)>