diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -315,17 +315,48 @@
   let printer = [{ printAllocaOp(p, *this); }];
 }
 
-def LLVM_GEPOp
-    : LLVM_Op<"getelementptr", [NoSideEffect]>,
-      LLVM_Builder<
-          "$res = builder.CreateGEP("
-          "  $base->getType()->getPointerElementType(), $base, $indices);"> {
+def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
-                   Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices);
+                   Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices,
+                   I32ElementsAttr:$structIndices);
   let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
-  let builders = [LLVM_OneResultOpBuilder];
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+    OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
+      "ArrayRef<int32_t>":$structIndices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+  ];
+  // `structIndices` has one entry per GEP index: either a constant, emitted as
+  // an i32 immediate (LLVM only accepts i32 constants as struct member
+  // indices), or kDynamicIndex, meaning the next `indices` operand is used.
+  let llvmBuilder = [{
+    SmallVector<llvm::Value *> indices;
+    indices.reserve($structIndices.size());
+    unsigned operandIdx = 0;
+    for (int32_t structIndex : $structIndices.getValues<int32_t>()) {
+      if (structIndex == GEPOp::kDynamicIndex)
+        indices.push_back($indices[operandIdx++]);
+      else
+        indices.push_back(builder.getInt32(structIndex));
+    }
+    $res = builder.CreateGEP(
+      $base->getType()->getPointerElementType(), $base, indices);
+  }];
   let assemblyFormat = [{
-    $base `[` $indices `]` attr-dict `:` functional-type(operands, results)
+    $base `[` custom<GEPIndices>($indices, $structIndices) `]` attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    /// Sentinel stored in `structIndices` at positions whose index is taken
+    /// from the next `indices` operand instead of being a constant.
+    constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
   }];
+
+  // Ensures the number of kDynamicIndex entries in `structIndices` equals the
+  // number of `indices` operands (see verify(GEPOp) in LLVMDialect.cpp).
+  let verifier = [{ return ::verify(*this); }];
 }
 
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -790,8 +790,8 @@
 
   Type elementPtrType = getElementPtrType(memRefType);
   Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
-  Value gepPtr = rewriter.create<LLVM::GEPOp>(
-      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
+  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
+                                              ArrayRef<Value>{numElements});
   auto sizeBytes =
       rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
 
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -162,8 +162,8 @@
   // Buffer size in bytes.
   Type elementPtrType = getElementPtrType(memRefType);
   Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
-  Value gepPtr = rewriter.create<LLVM::GEPOp>(
-      loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
+  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
+                                              ArrayRef<Value>{runningStride});
   sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
 }
 
@@ -178,8 +178,8 @@
       LLVM::LLVMPointerType::get(typeConverter->convertType(type));
   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
   auto gep = rewriter.create<LLVM::GEPOp>(
-      loc, convertedPtrType,
-      ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
+      loc, convertedPtrType, nullPtr,
+      ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)});
   return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
 }
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -497,10 +497,11 @@
     Type elementType = typeConverter->convertType(type.getElementType());
     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
 
-    SmallVector<Value, 4> operands = {addressOf};
+    SmallVector<Value, 4> operands;
     operands.insert(operands.end(), type.getRank() + 1,
                     createIndexConstant(rewriter, loc, 0));
-    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
+    auto gep =
+        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
 
     // We do not expect the memref obtained using `memref.get_global` to be
     // ever deallocated. Set the allocated pointer to be known bad value to
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -355,6 +355,101 @@
                     : getCaseOperandsMutable(index - 1);
 }
 
+//===----------------------------------------------------------------------===//
+// Code for LLVM::GEPOp.
+//===----------------------------------------------------------------------===//
+
+/// Builds a GEP in which every index is an SSA value: `structIndices` is
+/// filled with kDynamicIndex, one entry per element of `indices`.
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  Value basePtr, ValueRange indices,
+                  ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, resultType, basePtr, indices,
+        SmallVector<int32_t>(indices.size(), LLVM::GEPOp::kDynamicIndex),
+        attributes);
+}
+
+/// Builds a GEP mixing constant and SSA indices. `structIndices` lists all GEP
+/// indices in order; each kDynamicIndex entry stands for the next value of
+/// `indices`, and every other entry is a constant index.
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  Value basePtr, ValueRange indices,
+                  ArrayRef<int32_t> structIndices,
+                  ArrayRef<NamedAttribute> attributes) {
+  result.addTypes(resultType);
+  result.addAttributes(attributes);
+  result.addAttribute("structIndices", builder.getI32TensorAttr(structIndices));
+  result.addOperands(basePtr);
+  result.addOperands(indices);
+}
+
+/// Checks that `structIndices` holds exactly one kDynamicIndex entry per
+/// `indices` operand. The LLVM IR builder and printGEPIndices consume
+/// `indices` positionally without bounds checks, so a mismatch (reachable via
+/// the generic op syntax or a direct builder call) would read past the operand
+/// list or silently ignore trailing operands.
+static LogicalResult verify(GEPOp gepOp) {
+  size_t numDynamicIndices = 0;
+  for (int32_t structIndex : gepOp.structIndices().getValues<int32_t>())
+    if (structIndex == LLVM::GEPOp::kDynamicIndex)
+      ++numDynamicIndices;
+  if (numDynamicIndices != gepOp.indices().size())
+    return gepOp.emitOpError()
+           << "expected " << numDynamicIndices
+           << " dynamic index operand(s) to match 'structIndices', got "
+           << gepOp.indices().size();
+  return success();
+}
+
+/// Parses a non-empty, comma-separated list of GEP indices, each either an
+/// integer literal or an SSA operand. Literals are stored in `structIndices`
+/// as-is; operands are appended to `indices` and recorded as kDynamicIndex.
+static ParseResult
+parseGEPIndices(OpAsmParser &parser,
+                SmallVectorImpl<OpAsmParser::OperandType> &indices,
+                DenseIntElementsAttr &structIndices) {
+  SmallVector<int32_t> constantIndices;
+  do {
+    llvm::SMLoc indexLoc = parser.getCurrentLocation();
+    int32_t constantIndex;
+    OptionalParseResult parsedInteger =
+        parser.parseOptionalInteger(constantIndex);
+    if (parsedInteger.hasValue()) {
+      if (failed(parsedInteger.getValue()))
+        return failure();
+      // The sentinel cannot be spelled as a literal: the printer and the LLVM
+      // IR builder would treat it as a reference to a missing SSA operand.
+      if (constantIndex == LLVM::GEPOp::kDynamicIndex)
+        return parser.emitError(indexLoc, "constant GEP index ")
+               << constantIndex << " is reserved";
+      constantIndices.push_back(constantIndex);
+      continue;
+    }
+
+    constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
+    if (failed(parser.parseOperand(indices.emplace_back())))
+      return failure();
+  } while (succeeded(parser.parseOptionalComma()));
+
+  structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
+  return success();
+}
+
+/// Prints the GEP index list in the form accepted by parseGEPIndices.
+/// `gepOp` is unused; it is part of the custom-directive printer signature.
+static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
+                            OperandRange indices,
+                            DenseIntElementsAttr structIndices) {
+  unsigned operandIdx = 0;
+  llvm::interleaveComma(structIndices.getValues<int32_t>(), printer,
+                        [&](int32_t cst) {
+                          if (cst == LLVM::GEPOp::kDynamicIndex)
+                            printer.printOperand(indices[operandIdx++]);
+                          else
+                            printer << cst;
+                        });
+}
+
 //===----------------------------------------------------------------------===//
 // Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -760,7 +760,8 @@
     Type type = processType(inst->getType());
     if (!type)
       return failure();
-    v = b.create<GEPOp>(loc, type, ops);
+    v = b.create<GEPOp>(loc, type, ops[0],
+                        llvm::makeArrayRef(ops).drop_front());
     return success();
   }
   }
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -170,6 +170,16 @@
   llvm.return
 }
 
+// CHECK-LABEL: @gep
+llvm.func @gep(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64,
+               %ptr2: !llvm.ptr<struct<(array<10 x f32>)>>) {
+  // CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  // CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr<struct<(array<10 x f32>)>>, i64, i64) -> !llvm.ptr<f32>
+  llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr<struct<(array<10 x f32>)>>, i64, i64) -> !llvm.ptr<f32>
+  llvm.return
+}
+
 // An larger self-contained function.
 // CHECK-LABEL: llvm.func @foo(%{{.*}}: i32) -> !llvm.struct<(i32, f64, i32)> {
 llvm.func @foo(%arg0: i32) -> !llvm.struct<(i32, f64, i32)> {
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -975,6 +975,16 @@
   llvm.return %10 : !llvm.struct<(f32, i32)>
 }
 
+// CHECK-LABEL: @gep
+llvm.func @gep(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64,
+               %ptr2: !llvm.ptr<struct<(array<10 x f32>)>>) {
+  // CHECK: = getelementptr { i32, { i32, float } }, { i32, { i32, float } }* %{{.*}}, i64 %{{.*}}, i32 1, i32 0
+  llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  // CHECK: = getelementptr { [10 x float] }, { [10 x float] }* %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
+  llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr<struct<(array<10 x f32>)>>, i64, i64) -> !llvm.ptr<f32>
+  llvm.return
+}
+
 //
 // Indirect function calls
 //