diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td @@ -15,6 +15,7 @@ let name = "llvm"; let cppNamespace = "::mlir::LLVM"; + let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -98,9 +98,13 @@ def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">; def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">; def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">; -def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or">; +def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> { + let hasFolder = 1; +} def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">; -def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl">; +def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> { + let hasFolder = 1; +} def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">; def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">; @@ -495,10 +499,15 @@ LLVM_ScalarOrVectorOf>; def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt", LLVM_ScalarOrVectorOf, - LLVM_ScalarOrVectorOf>; + LLVM_ScalarOrVectorOf> { + let hasVerifier = 1; +} def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt", LLVM_ScalarOrVectorOf, - LLVM_ScalarOrVectorOf>; + LLVM_ScalarOrVectorOf> { + let hasFolder = 1; + let hasVerifier = 1; +} def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -51,6 +51,17 @@ return false; } +/// Returns the width of an integer or 
of the element type of an integer vector, +/// or std::nullopt for any other type. +static std::optional getIntegerOrVectorElementWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (auto vecType = dyn_cast(type)) + if (auto intType = dyn_cast(vecType.getElementType())) + return intType.getWidth(); + return std::nullopt; +} + /// Returns the bit width of integer, float or vector of float or integer values static unsigned getBitWidth(Type type) { assert((type.isIntOrFloat() || isa(type)) && @@ -1183,15 +1194,30 @@ return success(); } + std::optional dstTypeWidth = + getIntegerOrVectorElementWidth(dstType); + std::optional op2TypeWidth = + getIntegerOrVectorElementWidth(op2Type); + + if (!dstTypeWidth || !op2TypeWidth) + return failure(); + Location loc = operation.getLoc(); Value extended; - if (isUnsignedIntegerOrVector(op2Type)) { - extended = rewriter.template create(loc, dstType, - adaptor.getOperand2()); + if (op2TypeWidth < dstTypeWidth) { + if (isUnsignedIntegerOrVector(op2Type)) { + extended = rewriter.template create( + loc, dstType, adaptor.getOperand2()); + } else { + extended = rewriter.template create( + loc, dstType, adaptor.getOperand2()); + } + } else if (op2TypeWidth == dstTypeWidth) { + extended = adaptor.getOperand2(); } else { - extended = rewriter.template create(loc, dstType, - adaptor.getOperand2()); + return failure(); } + Value result = rewriter.template create( loc, dstType, adaptor.getOperand1(), extended); rewriter.replaceOp(operation, result); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/MLIRContext.h" @@ -30,8 +31,10 @@ 
#include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/Attributes.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Error.h" #include "llvm/Support/Mutex.h" #include "llvm/Support/SourceMgr.h" @@ -2448,6 +2451,45 @@ return success(); } +//===----------------------------------------------------------------------===// +// Verifiers for extension ops (LLVM::SExtOp and LLVM::ZExtOp) +//===----------------------------------------------------------------------===// + +template +static LogicalResult verifyExtOp(ExtOp op) { + IntegerType inputType, outputType; + if (isCompatibleVectorType(op.getArg().getType())) { + if (!isCompatibleVectorType(op.getResult().getType())) + return op.emitError( + "input type is a vector but output type is an integer"); + if (getVectorNumElements(op.getArg().getType()) != + getVectorNumElements(op.getResult().getType())) + return op.emitError("input and output vectors are of incompatible shape"); + // Because this is a CastOp, the element of vectors is guaranteed to be an + // integer. + inputType = cast(getVectorElementType(op.getArg().getType())); + outputType = + cast(getVectorElementType(op.getResult().getType())); + } else { + // Because this is a CastOp and arg is not a vector, arg is guaranteed to be + // an integer.
+ inputType = cast(op.getArg().getType()); + outputType = dyn_cast(op.getResult().getType()); + if (!outputType) + return op.emitError( + "input type is an integer but output type is a vector"); + } + + if (outputType.getWidth() <= inputType.getWidth()) + return op.emitError("integer width of the output type is smaller or " + "equal to the integer width of the input type"); + return success(); +} + +LogicalResult SExtOp::verify() { return verifyExtOp(*this); } + +LogicalResult ZExtOp::verify() { return verifyExtOp(*this); } + //===----------------------------------------------------------------------===// // Folder and verifier for LLVM::BitcastOp //===----------------------------------------------------------------------===// @@ -2564,6 +2606,55 @@ return {}; } +//===----------------------------------------------------------------------===// +// Folder for LLVM::ZExtOp +//===----------------------------------------------------------------------===// + +OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) { + auto arg = dyn_cast_or_null(adaptor.getArg()); + if (!arg) + return {}; + + size_t targetSize = cast(getType()).getWidth(); + return IntegerAttr::get(getType(), arg.getValue().zext(targetSize)); +} + +//===----------------------------------------------------------------------===// +// Folder for LLVM::ShlOp +//===----------------------------------------------------------------------===// + +OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) { + auto rhs = dyn_cast_or_null(adaptor.getRhs()); + if (!rhs) + return {}; + + // Shifts by >= the bit width are poison; uge() accepts any APInt width. + if (rhs.getValue().uge(getLhs().getType().getIntOrFloatBitWidth())) + return {}; // TODO: Fold into poison.
+ + auto lhs = dyn_cast_or_null(adaptor.getLhs()); + if (!lhs) + return {}; + + return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue())); +} + +//===----------------------------------------------------------------------===// +// Folder for LLVM::OrOp +//===----------------------------------------------------------------------===// + +OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) { + auto lhs = dyn_cast_or_null(adaptor.getLhs()); + if (!lhs) + return {}; + + auto rhs = dyn_cast_or_null(adaptor.getRhs()); + if (!rhs) + return {}; + + return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue()); +} + //===----------------------------------------------------------------------===// // Utilities for LLVM::MetadataOp //===----------------------------------------------------------------------===// @@ -3103,6 +3194,11 @@ return verifyParameterAttribute(op, resType, resAttr); } +Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + return builder.create(loc, type, value); +} + //===----------------------------------------------------------------------===// // Utility functions.
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/constant-folding.mlir b/mlir/test/Dialect/LLVMIR/constant-folding.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/constant-folding.mlir @@ -0,0 +1,64 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(canonicalize))" --split-input-file | FileCheck %s + +// CHECK-LABEL: llvm.func @zext_basic +llvm.func @zext_basic() -> i64 { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.zext %0 : i32 to i64 + // CHECK: %[[RES:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: llvm.return %[[RES]] : i64 + llvm.return %1 : i64 +} + +// CHECK-LABEL: llvm.func @zext_neg +llvm.func @zext_neg() -> i64 { + %0 = llvm.mlir.constant(-1 : i32) : i32 + %1 = llvm.zext %0 : i32 to i64 + // CHECK: %[[RES:.*]] = llvm.mlir.constant(4294967295 : i64) : i64 + // CHECK: llvm.return %[[RES]] : i64 + llvm.return %1 : i64 +} + +// ----- + +// CHECK-LABEL: llvm.func @shl_basic +llvm.func @shl_basic() -> i32 { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + %2 = llvm.shl %0, %1 : i32 + // CHECK: %[[RES:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %2 : i32 +} + +// CHECK-LABEL: llvm.func @shl_multiple +llvm.func @shl_multiple() -> i32 { + %0 = llvm.mlir.constant(45 : i32) : i32 + %1 = llvm.mlir.constant(7 : i32) : i32 + %2 = llvm.shl %0, %1 : i32 + // CHECK: %[[RES:.*]] = llvm.mlir.constant(5760 : i32) : i32 + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %2 : i32 +} + +// CHECK-LABEL: llvm.func @shl_oversized +llvm.func @shl_oversized() -> i32 { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.constant(32 : i32) : i32 + // Shift amounts >= the bit width produce poison and must not be folded. 
+ // CHECK: %[[RES:.*]] = llvm.shl + %2 = llvm.shl %0, %1 : i32 + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %2 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @or_basic +llvm.func @or_basic() -> i32 { + %0 = llvm.mlir.constant(5 : i32) : i32 + %1 = llvm.mlir.constant(9 : i32) : i32 + %2 = llvm.or %0, %1 : i32 + // CHECK: %[[RES:.*]] = llvm.mlir.constant(13 : i32) : i32 + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %2 : i32 +} diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir
b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1423,3 +1423,45 @@ // expected-error@+1 {{only zero-initializer allowed for target extension types}} %0 = llvm.mlir.constant(42 : index) : !llvm.target<"spirv.Event"> } + +// ----- + +func.func @invalid_zext_target_size(%arg: i32) { + // expected-error@+1 {{integer width of the output type is smaller or equal to the integer width of the input type}} + %0 = llvm.zext %arg : i32 to i16 +} + +// ----- + +func.func @invalid_zext_target_size_equal(%arg: i32) { + // expected-error@+1 {{integer width of the output type is smaller or equal to the integer width of the input type}} + %0 = llvm.zext %arg : i32 to i32 +} + +// ----- + +func.func @invalid_zext_target_size_vector(%arg: vector<1xi32>) { + // expected-error@+1 {{integer width of the output type is smaller or equal to the integer width of the input type}} + %0 = llvm.zext %arg : vector<1xi32> to vector<1xi16> +} + +// ----- + +func.func @invalid_zext_target_shape(%arg: vector<1xi32>) { + // expected-error@+1 {{input and output vectors are of incompatible shape}} + %0 = llvm.zext %arg : vector<1xi32> to vector<2xi64> +} + +// ----- + +func.func @invalid_zext_target_type(%arg: i32) { + // expected-error@+1 {{input type is an integer but output type is a vector}} + %0 = llvm.zext %arg : i32 to vector<1xi64> +} + +// ----- + +func.func @invalid_sext_target_size(%arg: i32) { + // expected-error@+1 {{integer width of the output type is smaller or equal to the integer width of the input type}} + %0 = llvm.sext %arg : i32 to i16 +} diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir --- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir +++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir @@ -1,12 +1,11 @@ // RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s // CHECK-LABEL: llvm.func @basic_memset -llvm.func @basic_memset() -> i32 { +// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8) +llvm.func @basic_memset(%memset_value: i8) -> i32 { %0 = llvm.mlir.constant(1 : i32) : i32 %1 = 
llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr - %memset_value = llvm.mlir.constant(42 : i8) : i8 %memset_len = llvm.mlir.constant(4 : i32) : i32 - // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8 // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () @@ -24,36 +23,27 @@ // ----- -// CHECK-LABEL: llvm.func @allow_dynamic_value_memset -// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8) -llvm.func @allow_dynamic_value_memset(%memset_value: i8) -> i32 { +// CHECK-LABEL: llvm.func @basic_memset_constant +llvm.func @basic_memset_constant() -> i32 { %0 = llvm.mlir.constant(1 : i32) : i32 %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr + %memset_value = llvm.mlir.constant(42 : i8) : i8 %memset_len = llvm.mlir.constant(4 : i32) : i32 - // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () - // CHECK-NOT: "llvm.intr.memset" - // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32 - // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]] - // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]] - // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]] - // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]] - // CHECK-NOT: "llvm.intr.memset" %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32 - // CHECK: llvm.return %[[VALUE_32]] : i32 + // CHECK: %[[RES:.*]] = llvm.mlir.constant(707406378 : i32) : i32 + // CHECK: llvm.return %[[RES]] : i32 llvm.return %2 : i32 } // ----- // CHECK-LABEL: llvm.func @exotic_target_memset -llvm.func @exotic_target_memset() -> i40 { +// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8) +llvm.func 
@exotic_target_memset(%memset_value: i8) -> i40 { %0 = llvm.mlir.constant(1 : i32) : i32 %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr - %memset_value = llvm.mlir.constant(42 : i8) : i8 %memset_len = llvm.mlir.constant(5 : i32) : i32 - // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8 // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40 // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40 // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40 @@ -74,6 +64,21 @@ // ----- +// CHECK-LABEL: llvm.func @exotic_target_memset_constant +llvm.func @exotic_target_memset_constant() -> i40 { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr + %memset_value = llvm.mlir.constant(42 : i8) : i8 + %memset_len = llvm.mlir.constant(5 : i32) : i32 + "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () + %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40 + // CHECK: %[[RES:.*]] = llvm.mlir.constant(181096032810 : i40) : i40 + // CHECK: llvm.return %[[RES]] : i40 + llvm.return %2 : i40 +} + +// ----- + // CHECK-LABEL: llvm.func @no_volatile_memset llvm.func @no_volatile_memset() -> i32 { // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32