diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -575,6 +575,8 @@ def LLVM_AddressOfOp : LLVM_OneResultOp<"mlir.addressof">, Arguments<(ins FlatSymbolRefAttr:$global_name)> { + let summary = "Creates a pointer pointing to a global or a function"; + let builders = [ OpBuilder<"OpBuilder &builder, OperationState &result, LLVMType resType, " "StringRef name, ArrayRef<NamedAttribute> attrs = {}", [{ @@ -586,13 +588,21 @@ "ArrayRef<NamedAttribute> attrs = {}", [{ build(builder, result, global.getType().getPointerTo(global.addr_space().getZExtValue()), - global.sym_name(), attrs);}]> + global.sym_name(), attrs);}]>, + + OpBuilder<"OpBuilder &builder, OperationState &result, LLVMFuncOp func, " + "ArrayRef<NamedAttribute> attrs = {}", [{ + build(builder, result, + func.getType().getPointerTo(), func.getName(), attrs);}]> ]; let extraClassDeclaration = [{ /// Return the llvm.mlir.global operation that defined the value referenced /// here. GlobalOp getGlobal(); + + /// Return the llvm.func operation that is referenced here.
+ LLVMFuncOp getFunction(); }]; let assemblyFormat = "$global_name attr-dict `:` type($res)"; @@ -733,6 +743,7 @@ LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);"> { let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)"; + let verifier = [{ return ::verify(*this); }]; } def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>, diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1316,8 +1316,6 @@ using AddIOpLowering = VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp>; using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>; using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>; -using ConstLLVMOpLowering = - OneToOneConvertToLLVMPattern<ConstantOp, LLVM::ConstantOp>; using CopySignOpLowering = VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>; using CosOpLowering = VectorConvertToLLVMPattern<CosOp, LLVM::CosOp>; @@ -1491,6 +1489,40 @@ } }; +/// Lowers std.constant: a flat symbol reference becomes llvm.mlir.addressof, +/// nested ones fail to match, and other values become llvm.mlir.constant. +struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> { + using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *operation, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto op = cast<ConstantOp>(operation); + // A constant referring to a function lowers to llvm.mlir.addressof.
+ if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) { + auto type = typeConverter.convertType(op.getResult().getType()) + .dyn_cast_or_null<LLVM::LLVMType>(); + if (!type) + return failure(); + + MutableDictionaryAttr attrs(op.getAttrs()); + attrs.remove(rewriter.getIdentifier("value")); + rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>( + op, type, symbolRef.getValue(), + attrs.getAttrs()); + return success(); + } + + // Nested (non-flat) symbol references cannot be expressed in LLVM. + if (op.getValue().isa<SymbolRefAttr>()) + return failure(); + + return LLVM::detail::oneToOneRewrite(op, + LLVM::ConstantOp::getOperationName(), + operands, typeConverter, rewriter); + } +}; + // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. static bool isSupportedMemRefType(MemRefType type) { @@ -2980,7 +3012,7 @@ CondBranchOpLowering, CopySignOpLowering, CosOpLowering, - ConstLLVMOpLowering, + ConstantOpLowering, CreateComplexOpLowering, DialectCastOpLowering, DivFOpLowering, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -857,25 +857,40 @@ // Verifier for LLVM::AddressOfOp.
//===----------------------------------------------------------------------===// -GlobalOp AddressOfOp::getGlobal() { - Operation *module = getParentOp(); +template <typename OpTy> +static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { + Operation *module = parent; while (module && !satisfiesLLVMModule(module)) module = module->getParentOp(); assert(module && "unexpected operation outside of a module"); - return dyn_cast_or_null<LLVM::GlobalOp>( - mlir::SymbolTable::lookupSymbolIn(module, global_name())); + return dyn_cast_or_null<OpTy>( + mlir::SymbolTable::lookupSymbolIn(module, name)); +} + +GlobalOp AddressOfOp::getGlobal() { + return lookupSymbolInModule<LLVM::GlobalOp>(getParentOp(), global_name()); +} + +LLVMFuncOp AddressOfOp::getFunction() { + return lookupSymbolInModule<LLVM::LLVMFuncOp>(getParentOp(), global_name()); } static LogicalResult verify(AddressOfOp op) { auto global = op.getGlobal(); - if (!global) + auto function = op.getFunction(); + if (!global && !function) + return op.emitOpError( + "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); + + if (global && + global.getType().getPointerTo(global.addr_space().getZExtValue()) != + op.getResult().getType()) return op.emitOpError( - "must reference a global defined by 'llvm.mlir.global'"); + "the type must be a pointer to the type of the referenced global"); - if (global.getType().getPointerTo(global.addr_space().getZExtValue()) != - op.getResult().getType()) + if (function && function.getType().getPointerTo() != op.getResult().getType()) return op.emitOpError( - "the type must be a pointer to the type of the referred global"); + "the type must be a pointer to the type of the referenced function"); return success(); } @@ -1395,6 +1410,18 @@ return success(); } +//===----------------------------------------------------------------------===// +// Verification for LLVM::ConstantOp.
+//===----------------------------------------------------------------------===// + +static LogicalResult verify(LLVM::ConstantOp op) { + if (!(op.value().isa<IntegerAttr>() || op.value().isa<FloatAttr>() || + op.value().isa<ElementsAttr>() || op.value().isa<StringAttr>())) + return op.emitOpError() + << "only supports integer, float, string or elements attributes"; + return success(); +} + //===----------------------------------------------------------------------===// // Utility functions for parsing atomic ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -405,6 +405,9 @@ LLVMType type = processType(c->getType()); if (!type) return nullptr; + if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>()) + return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type, + symbolRef.getValue()); return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr); } if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -447,10 +447,15 @@ // emit any LLVM instruction. if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) { LLVM::GlobalOp global = addressOfOp.getGlobal(); + LLVM::LLVMFuncOp function = addressOfOp.getFunction(); + // The verifier should not have allowed this. - assert(global && "referencing an undefined global"); + assert((global || function) && + "referencing an undefined global or function"); - valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global); + valueMapping[addressOfOp.getResult()] = + global ?
globalsMapping.lookup(global) + : functionMapping.lookup(function.getName()); return success(); } diff --git a/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir b/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir @@ -31,7 +31,7 @@ // CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm<"void ()*">) br ^bb1(%arg0 : () -> ()) -//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">): // pred: ^bb0 +//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">): ^bb1(%bbarg: () -> ()): // CHECK-NEXT: llvm.return %0 : !llvm<"void ()*"> return %bbarg : () -> () @@ -40,11 +40,12 @@ // CHECK-LABEL: llvm.func @body(!llvm.i32) func @body(i32) -// CHECK-LABEL: llvm.func @indirect_const_call(%arg0: !llvm.i32) { +// CHECK-LABEL: llvm.func @indirect_const_call +// CHECK-SAME: (%[[ARG0:.*]]: !llvm.i32) { func @indirect_const_call(%arg0: i32) { -// CHECK-NEXT: %0 = llvm.mlir.constant(@body) : !llvm<"void (i32)*"> +// CHECK-NEXT: %[[ADDR:.*]] = llvm.mlir.addressof @body : !llvm<"void (i32)*"> %0 = constant @body : (i32) -> () -// CHECK-NEXT: llvm.call %0(%arg0) : (!llvm.i32) -> () +// CHECK-NEXT: llvm.call %[[ADDR]](%[[ARG0]]) : (!llvm.i32) -> () call_indirect %0(%arg0) : (i32) -> () // CHECK-NEXT: llvm.return return diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -140,12 +140,21 @@ llvm.mlir.global internal @foo(0: i32) : !llvm.i32 func @bar() { - // expected-error @+1 {{the type must be a pointer to the type of the referred global}} + // expected-error @+1 {{the type must be a pointer to the type of the referenced global}} llvm.mlir.addressof @foo : !llvm<"i64*"> } // ----- +llvm.func @foo() + +llvm.func @bar() { + // expected-error @+1 {{the type must be a pointer to the type of the referenced function}} + llvm.mlir.addressof @foo : !llvm<"i8*"> +} + +// ----- + // 
expected-error @+2 {{'llvm.mlir.global' op expects regions to end with 'llvm.return', found 'llvm.mlir.constant'}} // expected-note @+1 {{in custom textual format, the absence of terminator implies 'llvm.return'}} llvm.mlir.global internal @g() : !llvm.i64 { @@ -172,7 +181,7 @@ llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : !llvm.i64 func @mismatch_addr_space_implicit_global() { - // expected-error @+1 {{op the type must be a pointer to the type of the referred global}} + // expected-error @+1 {{op the type must be a pointer to the type of the referenced global}} llvm.mlir.addressof @g : !llvm<"i64*"> } @@ -180,6 +189,6 @@ llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : !llvm.i64 func @mismatch_addr_space() { - // expected-error @+1 {{op the type must be a pointer to the type of the referred global}} + // expected-error @+1 {{op the type must be a pointer to the type of the referenced global}} llvm.mlir.addressof @g : !llvm<"i64 addrspace(4)*"> } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -153,6 +153,13 @@ // ----- +func @constant_wrong_type() { + // expected-error@+1 {{only supports integer, float, string or elements attributes}} + llvm.mlir.constant(@constant_wrong_type) : !llvm<"void ()*"> +} + +// ----- + func @insertvalue_non_llvm_type(%a : i32, %b : i32) { // expected-error@+1 {{expected LLVM IR Dialect type}} llvm.insertvalue %a, %b[0] : i32 diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -52,12 +52,12 @@ // CHECK-NEXT: %17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }"> // CHECK-NEXT: %18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }"> // CHECK-NEXT: %19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }">
-// CHECK-NEXT: %20 = llvm.mlir.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*"> +// CHECK-NEXT: %20 = llvm.mlir.addressof @foo : !llvm<"{ i32, double, i32 } (i32)*"> // CHECK-NEXT: %21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }"> %17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }"> %18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }"> %19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }"> - %20 = llvm.mlir.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*"> + %20 = llvm.mlir.addressof @foo : !llvm<"{ i32, double, i32 } (i32)*"> %21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }"> @@ -114,8 +114,8 @@ } // An larger self-contained function. -// CHECK-LABEL:func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> { -func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> { +// CHECK-LABEL:llvm.func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> { +llvm.func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> { // CHECK-NEXT: %0 = llvm.mlir.constant(3 : i64) : !llvm.i32 // CHECK-NEXT: %1 = llvm.mlir.constant(3 : i64) : !llvm.i32 // CHECK-NEXT: %2 = llvm.mlir.constant(4.200000e+01 : f64) : !llvm.double @@ -315,4 +315,4 @@ // CHECK: release llvm.fence release return -} \ No newline at end of file +} diff --git a/mlir/test/Target/import.ll b/mlir/test/Target/import.ll --- a/mlir/test/Target/import.ll +++ b/mlir/test/Target/import.ll @@ -234,7 +234,7 @@ ; CHECK-LABEL: @precaller define i32 @precaller() { %1 = alloca i32 ()* - ; CHECK: %[[func:.*]] = llvm.mlir.constant(@callee) : !llvm<"i32 ()*"> + ; CHECK: %[[func:.*]] = llvm.mlir.addressof @callee : !llvm<"i32 ()*"> ; CHECK: llvm.store %[[func]], %[[loc:.*]] store i32 ()* @callee, i32 ()** %1 ; CHECK: %[[indir:.*]] = llvm.load %[[loc]] @@ -252,7 +252,7 @@ ; CHECK-LABEL: @postcaller define i32 @postcaller() { %1 = alloca i32 ()* - ; CHECK: %[[func:.*]] = llvm.mlir.constant(@callee) : !llvm<"i32 ()*"> + ; 
CHECK: %[[func:.*]] = llvm.mlir.addressof @callee : !llvm<"i32 ()*"> ; CHECK: llvm.store %[[func]], %[[loc:.*]] store i32 ()* @callee, i32 ()** %1 ; CHECK: %[[indir:.*]] = llvm.load %[[loc]] @@ -317,4 +317,4 @@ ;CHECK: llvm.fence seq_cst fence syncscope("") seq_cst ret i32 0 -} \ No newline at end of file +} diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -886,7 +886,7 @@ // CHECK-LABEL: define void @indirect_const_call(i64 {{%.*}}) llvm.func @indirect_const_call(%arg0: !llvm.i64) { // CHECK-NEXT: call void @body(i64 %0) - %0 = llvm.mlir.constant(@body) : !llvm<"void (i64)*"> + %0 = llvm.mlir.addressof @body : !llvm<"void (i64)*"> llvm.call %0(%arg0) : (!llvm.i64) -> () // CHECK-NEXT: ret void llvm.return