diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md --- a/mlir/docs/Dialects/LLVM.md +++ b/mlir/docs/Dialects/LLVM.md @@ -313,10 +313,28 @@ Selection: `select , , `. -### Auxiliary MLIR operations - -These operations do not have LLVM IR counterparts but are necessary to map LLVM -IR into MLIR. They should be prefixed with `llvm.mlir`. +### Auxiliary MLIR Operations for Constants and Globals + +LLVM IR has broad support for first-class constants, which is not the case for +MLIR. Instead, constants are defined in MLIR as regular SSA values produced by +operations with specific traits. The LLVM dialect provides a set of operations +that model LLVM IR constants. These operations do not correspond to LLVM IR +instructions and are therefore prefixed with `llvm.mlir`. + +Inline constants can be created by `llvm.mlir.constant`, which +supports integer, float, string or elements attributes (constant structs are not +currently supported). LLVM IR constant expressions are expected to be +constructed as sequences of regular operations on SSA values produced by +`llvm.mlir.constant`. Additionally, MLIR provides semantically-charged +operations `llvm.mlir.undef` and `llvm.mlir.null` for the corresponding +constants. + +LLVM IR globals can be defined using `llvm.mlir.global` at the module level, +except for functions that are defined with `llvm.func`. Globals, both variables +and functions, can be accessed by taking their address with the +`llvm.mlir.addressof` operation, which produces a pointer to the named global, +unlike `llvm.mlir.constant`, which produces a value of the same type as the +constant. #### `llvm.mlir.addressof` @@ -328,11 +346,17 @@ ```mlir func @foo() { - // Get the address of a global. + // Get the address of a global variable. %0 = llvm.mlir.addressof @const : !llvm<"i32*"> // Use it as a regular pointer. %1 = llvm.load %0 : !llvm<"i32*"> + + // Get the address of a function.
+ %2 = llvm.mlir.addressof @foo : !llvm<"void ()*"> + + // The function address can be used for indirect calls. + llvm.call %2() : () -> () } // Define the global. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -575,6 +575,8 @@ def LLVM_AddressOfOp : LLVM_OneResultOp<"mlir.addressof">, Arguments<(ins FlatSymbolRefAttr:$global_name)> { + let summary = "Creates a pointer pointing to a global or a function"; + let builders = [ OpBuilder<"OpBuilder &builder, OperationState &result, LLVMType resType, " "StringRef name, ArrayRef attrs = {}", [{ @@ -586,13 +588,21 @@ "ArrayRef attrs = {}", [{ build(builder, result, global.getType().getPointerTo(global.addr_space().getZExtValue()), - global.sym_name(), attrs);}]> + global.sym_name(), attrs);}]>, + + OpBuilder<"OpBuilder &builder, OperationState &result, LLVMFuncOp func, " + "ArrayRef attrs = {}", [{ + build(builder, result, + func.getType().getPointerTo(), func.getName(), attrs);}]> ]; let extraClassDeclaration = [{ /// Return the llvm.mlir.global operation that defined the value referenced /// here. GlobalOp getGlobal(); + + /// Return the llvm.func operation that is referenced here. 
+ LLVMFuncOp getFunction(); }]; let assemblyFormat = "$global_name attr-dict `:` type($res)"; @@ -733,6 +743,7 @@ LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);"> { let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)"; + let verifier = [{ return ::verify(*this); }]; } def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>, diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1366,8 +1366,6 @@ using AddIOpLowering = VectorConvertToLLVMPattern; using AndOpLowering = VectorConvertToLLVMPattern; using CeilFOpLowering = VectorConvertToLLVMPattern; -using ConstLLVMOpLowering = - OneToOneConvertToLLVMPattern; using CopySignOpLowering = VectorConvertToLLVMPattern; using CosOpLowering = VectorConvertToLLVMPattern; @@ -1541,6 +1539,39 @@ } }; +struct ConstantOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *operation, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto op = cast(operation); + // If constant refers to a function, convert it to "addressof". + if (auto symbolRef = op.getValue().dyn_cast()) { + auto type = typeConverter.convertType(op.getResult().getType()) + .dyn_cast_or_null(); + if (!type) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + MutableDictionaryAttr attrs(op.getAttrs()); + attrs.remove(rewriter.getIdentifier("value")); + rewriter.replaceOpWithNewOp( + op, type.cast(), symbolRef.getValue(), + attrs.getAttrs()); + return success(); + } + + // Calling into other scopes (non-flat reference) is not supported in LLVM. 
+ if (op.getValue().isa()) + return rewriter.notifyMatchFailure( + op, "referring to a symbol outside of the current module"); + + return LLVM::detail::oneToOneRewrite(op, + LLVM::ConstantOp::getOperationName(), + operands, typeConverter, rewriter); + } +}; + // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. static bool isSupportedMemRefType(MemRefType type) { @@ -3129,7 +3160,7 @@ CondBranchOpLowering, CopySignOpLowering, CosOpLowering, - ConstLLVMOpLowering, + ConstantOpLowering, CreateComplexOpLowering, DialectCastOpLowering, DivFOpLowering, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -857,25 +857,40 @@ // Verifier for LLVM::AddressOfOp. //===----------------------------------------------------------------------===// -GlobalOp AddressOfOp::getGlobal() { - Operation *module = getParentOp(); +template +static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { + Operation *module = parent; while (module && !satisfiesLLVMModule(module)) module = module->getParentOp(); assert(module && "unexpected operation outside of a module"); - return dyn_cast_or_null( - mlir::SymbolTable::lookupSymbolIn(module, global_name())); + return dyn_cast_or_null( + mlir::SymbolTable::lookupSymbolIn(module, name)); +} + +GlobalOp AddressOfOp::getGlobal() { + return lookupSymbolInModule(getParentOp(), global_name()); +} + +LLVMFuncOp AddressOfOp::getFunction() { + return lookupSymbolInModule(getParentOp(), global_name()); } static LogicalResult verify(AddressOfOp op) { auto global = op.getGlobal(); - if (!global) + auto function = op.getFunction(); + if (!global && !function) + return op.emitOpError( + "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); + + if (global && + 
global.getType().getPointerTo(global.addr_space().getZExtValue()) != + op.getResult().getType()) return op.emitOpError( - "must reference a global defined by 'llvm.mlir.global'"); + "the type must be a pointer to the type of the referenced global"); - if (global.getType().getPointerTo(global.addr_space().getZExtValue()) != - op.getResult().getType()) + if (function && function.getType().getPointerTo() != op.getResult().getType()) return op.emitOpError( - "the type must be a pointer to the type of the referred global"); + "the type must be a pointer to the type of the referenced function"); return success(); } @@ -1396,6 +1411,18 @@ } //===----------------------------------------------------------------------===// +// Verification for LLVM::ConstantOp. +//===----------------------------------------------------------------------===// + +static LogicalResult verify(LLVM::ConstantOp op) { + if (!(op.value().isa() || op.value().isa() || + op.value().isa() || op.value().isa())) + return op.emitOpError() + << "only supports integer, float, string or elements attributes"; + return success(); +} + +//===----------------------------------------------------------------------===// // Utility functions for parsing atomic ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -405,6 +405,9 @@ LLVMType type = processType(c->getType()); if (!type) return nullptr; + if (auto symbolRef = attr.dyn_cast()) + return instMap[c] = bEntry.create(unknownLoc, type, + symbolRef.getValue()); return instMap[c] = bEntry.create(unknownLoc, type, attr); } if (auto *cn = dyn_cast(c)) { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ 
b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -447,10 +447,15 @@ // emit any LLVM instruction. if (auto addressOfOp = dyn_cast(opInst)) { LLVM::GlobalOp global = addressOfOp.getGlobal(); + LLVM::LLVMFuncOp function = addressOfOp.getFunction(); + // The verifier should not have allowed this. - assert(global && "referencing an undefined global"); + assert((global || function) && + "referencing an undefined global or function"); - valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global); + valueMapping[addressOfOp.getResult()] = + global ? globalsMapping.lookup(global) + : functionMapping.lookup(function.getName()); return success(); } diff --git a/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir b/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir @@ -31,7 +31,7 @@ // CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm<"void ()*">) br ^bb1(%arg0 : () -> ()) -//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">): // pred: ^bb0 +//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">): ^bb1(%bbarg: () -> ()): // CHECK-NEXT: llvm.return %0 : !llvm<"void ()*"> return %bbarg : () -> () @@ -40,11 +40,12 @@ // CHECK-LABEL: llvm.func @body(!llvm.i32) func @body(i32) -// CHECK-LABEL: llvm.func @indirect_const_call(%arg0: !llvm.i32) { +// CHECK-LABEL: llvm.func @indirect_const_call +// CHECK-SAME: (%[[ARG0:.*]]: !llvm.i32) { func @indirect_const_call(%arg0: i32) { -// CHECK-NEXT: %0 = llvm.mlir.constant(@body) : !llvm<"void (i32)*"> +// CHECK-NEXT: %[[ADDR:.*]] = llvm.mlir.addressof @body : !llvm<"void (i32)*"> %0 = constant @body : (i32) -> () -// CHECK-NEXT: llvm.call %0(%arg0) : (!llvm.i32) -> () +// CHECK-NEXT: llvm.call %[[ADDR]](%[[ARG0]]) : (!llvm.i32) -> () call_indirect %0(%arg0) : (i32) -> () // CHECK-NEXT: llvm.return return diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir --- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -140,12 +140,21 @@ llvm.mlir.global internal @foo(0: i32) : !llvm.i32 func @bar() { - // expected-error @+1 {{the type must be a pointer to the type of the referred global}} + // expected-error @+1 {{the type must be a pointer to the type of the referenced global}} llvm.mlir.addressof @foo : !llvm<"i64*"> } // ----- +llvm.func @foo() + +llvm.func @bar() { + // expected-error @+1 {{the type must be a pointer to the type of the referenced function}} + llvm.mlir.addressof @foo : !llvm<"i8*"> +} + +// ----- + // expected-error @+2 {{'llvm.mlir.global' op expects regions to end with 'llvm.return', found 'llvm.mlir.constant'}} // expected-note @+1 {{in custom textual format, the absence of terminator implies 'llvm.return'}} llvm.mlir.global internal @g() : !llvm.i64 { @@ -172,7 +181,7 @@ llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : !llvm.i64 func @mismatch_addr_space_implicit_global() { - // expected-error @+1 {{op the type must be a pointer to the type of the referred global}} + // expected-error @+1 {{op the type must be a pointer to the type of the referenced global}} llvm.mlir.addressof @g : !llvm<"i64*"> } @@ -180,6 +189,6 @@ llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : !llvm.i64 func @mismatch_addr_space() { - // expected-error @+1 {{op the type must be a pointer to the type of the referred global}} + // expected-error @+1 {{op the type must be a pointer to the type of the referenced global}} llvm.mlir.addressof @g : !llvm<"i64 addrspace(4)*"> } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -153,6 +153,13 @@ // ----- +func @constant_wrong_type() { + // expected-error@+1 {{only supports integer, float, string or elements attributes}} + llvm.mlir.constant(@constant_wrong_type) : !llvm<"void ()*"> +} + +// ----- + func @insertvalue_non_llvm_type(%a : 
i32, %b : i32) { // expected-error@+1 {{expected LLVM IR Dialect type}} llvm.insertvalue %a, %b[0] : i32 diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -55,12 +55,12 @@ // CHECK: %[[STRUCT:.*]] = llvm.call @foo(%[[I32]]) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }"> // CHECK: %[[VALUE:.*]] = llvm.extractvalue %[[STRUCT]][0] : !llvm<"{ i32, double, i32 }"> // CHECK: %[[NEW_STRUCT:.*]] = llvm.insertvalue %[[VALUE]], %[[STRUCT]][2] : !llvm<"{ i32, double, i32 }"> -// CHECK: %[[FUNC:.*]] = llvm.mlir.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*"> +// CHECK: %[[FUNC:.*]] = llvm.mlir.addressof @foo : !llvm<"{ i32, double, i32 } (i32)*"> // CHECK: %{{.*}} = llvm.call %[[FUNC]](%[[I32]]) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }"> %17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }"> %18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }"> %19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }"> - %20 = llvm.mlir.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*"> + %20 = llvm.mlir.addressof @foo : !llvm<"{ i32, double, i32 } (i32)*"> %21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }"> @@ -130,8 +130,8 @@ } // An larger self-contained function. 
-// CHECK-LABEL: func @foo(%{{.*}}: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> { -func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> { +// CHECK-LABEL: llvm.func @foo(%{{.*}}: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> { +llvm.func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> { // CHECK: %[[V0:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i32 // CHECK: %[[V1:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i32 // CHECK: %[[V2:.*]] = llvm.mlir.constant(4.200000e+01 : f64) : !llvm.double diff --git a/mlir/test/Target/import.ll b/mlir/test/Target/import.ll --- a/mlir/test/Target/import.ll +++ b/mlir/test/Target/import.ll @@ -234,7 +234,7 @@ ; CHECK-LABEL: @precaller define i32 @precaller() { %1 = alloca i32 ()* - ; CHECK: %[[func:.*]] = llvm.mlir.constant(@callee) : !llvm<"i32 ()*"> + ; CHECK: %[[func:.*]] = llvm.mlir.addressof @callee : !llvm<"i32 ()*"> ; CHECK: llvm.store %[[func]], %[[loc:.*]] store i32 ()* @callee, i32 ()** %1 ; CHECK: %[[indir:.*]] = llvm.load %[[loc]] @@ -252,7 +252,7 @@ ; CHECK-LABEL: @postcaller define i32 @postcaller() { %1 = alloca i32 ()* - ; CHECK: %[[func:.*]] = llvm.mlir.constant(@callee) : !llvm<"i32 ()*"> + ; CHECK: %[[func:.*]] = llvm.mlir.addressof @callee : !llvm<"i32 ()*"> ; CHECK: llvm.store %[[func]], %[[loc:.*]] store i32 ()* @callee, i32 ()** %1 ; CHECK: %[[indir:.*]] = llvm.load %[[loc]] @@ -317,4 +317,4 @@ ;CHECK: llvm.fence seq_cst fence syncscope("") seq_cst ret i32 0 -} \ No newline at end of file +} diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -886,7 +886,7 @@ // CHECK-LABEL: define void @indirect_const_call(i64 {{%.*}}) llvm.func @indirect_const_call(%arg0: !llvm.i64) { // CHECK-NEXT: call void @body(i64 %0) - %0 = llvm.mlir.constant(@body) : !llvm<"void (i64)*"> + %0 = llvm.mlir.addressof @body : !llvm<"void (i64)*"> llvm.call %0(%arg0) : (!llvm.i64) -> () // CHECK-NEXT: ret void llvm.return