diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -7724,6 +7724,8 @@ = fadd float 4.0, %var ; yields float:result = 4.0 + %var +.. _i_sub: + '``sub``' Instruction ^^^^^^^^^^^^^^^^^^^^^ @@ -7819,6 +7821,8 @@ = fsub float 4.0, %var ; yields float:result = 4.0 - %var = fsub float -0.0, %val ; yields float:result = -%var +.. _i_mul: + '``mul``' Instruction ^^^^^^^^^^^^^^^^^^^^^ @@ -7913,6 +7917,8 @@ = fmul float 4.0, %var ; yields float:result = 4.0 * %var +.. _i_udiv: + '``udiv``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -7959,6 +7965,8 @@ = udiv i32 4, %var ; yields i32:result = 4 / %var +.. _i_sdiv: + '``sdiv``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -8047,6 +8055,8 @@ = fdiv float 4.0, %var ; yields float:result = 4.0 / %var +.. _i_urem: + '``urem``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -8091,6 +8101,8 @@ = urem i32 4, %var ; yields i32:result = 4 % %var +.. _i_srem: + '``srem``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -8204,6 +8216,8 @@ operands of the same type, execute an operation on them, and produce a single value. The resulting value is the same type as its operands. +.. _i_shl: + '``shl``' Instruction ^^^^^^^^^^^^^^^^^^^^^ @@ -8256,6 +8270,9 @@ = shl i32 1, 32 ; undefined = shl <2 x i32> < i32 1, i32 1>, < i32 1, i32 2> ; yields: result=<2 x i32> < i32 2, i32 4> +.. _i_lshr: + + '``lshr``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -8305,6 +8322,8 @@ = lshr i32 1, 32 ; undefined = lshr <2 x i32> < i32 -2, i32 4>, < i32 1, i32 2> ; yields: result=<2 x i32> < i32 0x7FFFFFFF, i32 1> +.. _i_ashr: + '``ashr``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -8355,6 +8374,8 @@ = ashr i32 1, 32 ; undefined = ashr <2 x i32> < i32 -2, i32 4>, < i32 1, i32 3> ; yields: result=<2 x i32> < i32 -1, i32 0> +.. _i_and: + '``and``' Instruction ^^^^^^^^^^^^^^^^^^^^^ @@ -8404,6 +8425,8 @@ = and i32 15, 40 ; yields i32:result = 8 = and i32 4, 8 ; yields i32:result = 0 +.. 
_i_or: + '``or``' Instruction ^^^^^^^^^^^^^^^^^^^^ @@ -8453,6 +8476,8 @@ = or i32 15, 40 ; yields i32:result = 47 = or i32 4, 8 ; yields i32:result = 12 +.. _i_xor: + '``xor``' Instruction ^^^^^^^^^^^^^^^^^^^^^ @@ -15256,6 +15281,770 @@ after performing the required machine specific adjustments. The pointer returned can then be :ref:`bitcast and executed `. + +.. _int_vp: + +Vector Predication Intrinsics +----------------------------- +VP intrinsics are intended for predicated SIMD/vector code. A typical VP +operation takes a vector mask and an explicit vector length parameter as in: + +:: + + llvm.vp..*( %x, %y, %mask, i32 %evl) + +The vector mask parameter (%mask) always has a vector of `i1` type, for example +`<32 x i1>`. The explicit vector length parameter always has the type `i32` and +is an unsigned integer value. The explicit vector length parameter (%evl) is in +the range: + +:: + + 0 <= %evl <= W, where W is the (total) number of vector elements + +Note that for :ref:`scalable vector types ` ``W`` is the runtime +length of the vector. + +The VP intrinsic has undefined behavior if ``%evl > W``. The explicit vector +length (%evl) creates a mask, %EVLmask, with all elements ``0 <= i < %evl`` set +to True, and all other lanes ``%evl <= i < W`` to False. A new mask %M is +calculated with an element-wise AND from %mask and %EVLmask: + +:: + + M = %mask AND %EVLmask + +A vector operation ```` on vectors ``A`` and ``B`` calculates: + +:: + + A B = { A[i] B[i] M[i] = True, and + { undef otherwise + +Optimization Hint +^^^^^^^^^^^^^^^^^ + +Some targets, such as AVX512, do not support the %evl parameter in hardware. +The use of an effective %evl is discouraged for those targets. The function +``TargetTransformInfo::hasActiveVectorLength()`` returns true when the target +has native support for %evl. + + +.. _int_vp_add: + +'``llvm.vp.add.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. 
+ +:: + + declare <16 x i32> @llvm.vp.add.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.add.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.add.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Predicated integer addition of two vectors of integers. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. + +Semantics: +"""""""""" + +The '``llvm.vp.add``' intrinsic performs integer addition (:ref:`add `) +of the first and second vector operand on each enabled lane. The result on +disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = add <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + +.. _int_vp_sub: + +'``llvm.vp.sub.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.sub.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.sub.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.sub.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Predicated integer subtraction of two vectors of integers. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. 
+ +Semantics: +"""""""""" + +The '``llvm.vp.sub``' intrinsic performs integer subtraction +(:ref:`sub <i_sub>`) of the first and second vector operand on each enabled +lane. The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.sub.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = sub <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + + +.. _int_vp_mul: + +'``llvm.vp.mul.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.mul.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.mul.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.mul.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Predicated integer multiplication of two vectors of integers. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. + +Semantics: +"""""""""" +The '``llvm.vp.mul``' intrinsic performs integer multiplication +(:ref:`mul <i_mul>`) of the first and second vector operand on each enabled +lane. The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.mul.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = mul <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_sdiv: + +'``llvm.vp.sdiv.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. 
+ +:: + + declare <16 x i32> @llvm.vp.sdiv.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.sdiv.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.sdiv.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Predicated, signed division of two vectors of integers. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. + +Semantics: +"""""""""" + +The '``llvm.vp.sdiv``' intrinsic performs signed division (:ref:`sdiv <i_sdiv>`) +of the first and second vector operand on each enabled lane. The result on +disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.sdiv.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = sdiv <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_udiv: + +'``llvm.vp.udiv.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.udiv.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.udiv.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.udiv.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Predicated, unsigned division of two vectors of integers. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. 
+ +Semantics: +"""""""""" + +The '``llvm.vp.udiv``' intrinsic performs unsigned division +(:ref:`udiv <i_udiv>`) of the first and second vector operand on each enabled +lane. The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.udiv.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = udiv <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + + +.. _int_vp_srem: + +'``llvm.vp.srem.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.srem.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.srem.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.srem.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Predicated computation of the signed remainder of two integer vectors. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. + +Semantics: +"""""""""" + +The '``llvm.vp.srem``' intrinsic computes the remainder of the signed division +(:ref:`srem <i_srem>`) of the first and second vector operand on each enabled +lane. The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.srem.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = srem <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + + +.. _int_vp_urem: + +'``llvm.vp.urem.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. 
+ +:: + + declare <16 x i32> @llvm.vp.urem.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.urem.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.urem.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Predicated computation of the unsigned remainder of two integer vectors. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. + +Semantics: +"""""""""" + +The '``llvm.vp.urem``' intrinsic computes the remainder of the unsigned division +(:ref:`urem `) of the first and second vector operand on each enabled +lane. The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.urem.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = urem <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_ashr: + +'``llvm.vp.ashr.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.ashr.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.ashr.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.ashr.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Vector-predicated arithmetic right-shift. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. 
+ +Semantics: +"""""""""" + +The '``llvm.vp.ashr``' intrinsic computes the arithmetic right shift +(:ref:`ashr `) of the first operand by the second operand on each +enabled lane. The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.ashr.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = ashr <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_lshr: + + +'``llvm.vp.lshr.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.lshr.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.lshr.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.lshr.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Vector-predicated logical right-shift. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. + +Semantics: +"""""""""" + +The '``llvm.vp.lshr``' intrinsic computes the logical right shift +(:ref:`lshr `) of the first operand by the second operand on each +enabled lane. The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.lshr.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = lshr <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_shl: + +'``llvm.vp.shl.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. 
+ +:: + + declare <16 x i32> @llvm.vp.shl.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.shl.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.shl.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Vector-predicated left shift. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. + +Semantics: +"""""""""" + +The '``llvm.vp.shl``' intrinsic computes the left shift (:ref:`shl `) of +the first operand by the second operand on each enabled lane. The result on +disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.shl.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = shl <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_or: + +'``llvm.vp.or.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.or.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.or.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.or.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Vector-predicated or. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. + +Semantics: +"""""""""" + +The '``llvm.vp.or``' intrinsic performs a bitwise or (:ref:`or `) of the +first two operands on each enabled lane. The result on disabled lanes is +undefined. + +Examples: +""""""""" + +.. 
code-block:: llvm + + %r = call <4 x i32> @llvm.vp.or.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = or <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_and: + +'``llvm.vp.and.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.and.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.and.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.and.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Vector-predicated and. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. + +Semantics: +"""""""""" + +The '``llvm.vp.and``' intrinsic performs a bitwise and (:ref:`and `) of +the first two operands on each enabled lane. The result on disabled lanes is +undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.and.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = and <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_xor: + +'``llvm.vp.xor.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.xor.v16i32 (<16 x i32> , <16 x i32> , <16 x i1> , i32 ) + declare @llvm.vp.xor.nxv4i32 ( , , , i32 ) + declare <256 x i64> @llvm.vp.xor.v256i64 (<256 x i64> , <256 x i64> , <256 x i1> , i32 ) + +Overview: +""""""""" + +Vector-predicated, bitwise xor. 
+ + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The +third operand is the vector mask and has the same number of elements as the +result vector type. The fourth operand is the explicit vector length of the +operation. + +Semantics: +"""""""""" + +The '``llvm.vp.xor``' intrinsic performs a bitwise xor (:ref:`xor <i_xor>`) of +the first two operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.xor.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl) + ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r + + %t = xor <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + + +.. _int_vp_compose: + +'``llvm.vp.compose.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x float> @llvm.vp.compose.v16f32 (<16 x float> , <16 x float> , i32 , i32 ) + +Overview: +""""""""" + +The compose intrinsic blends two input vectors based on a pivot value. + + +Arguments: +"""""""""" + +The first operand is the vector whose elements are selected below the pivot. The second operand is the vector whose values are selected starting from the pivot position. The third operand is the pivot value. The fourth operand is the explicit vector length of the operation. + + +Semantics: +"""""""""" + +The '``llvm.vp.compose``' intrinsic is designed for conditional blending of two vectors based on a pivot number. All lanes below the pivot are taken from the first operand, all elements at positions greater than or equal to the pivot are taken from the second operand. It is useful for targets that support an explicit vector length and guarantee that vector instructions preserve the contents of vector registers above the AVL of the operation. 
Other targets may support this intrinsic differently, for example by lowering it into a select with a bitmask that represents the pivot comparison. +The result of this operation is equivalent to a select with an equivalent predicate mask based on the pivot operand. However, as for all VP intrinsics, all lanes at positions greater than or equal to the explicit vector length are undefined. + + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.compose.v4i32(<4 x i32> %a, <4 x i32> %b, i32 %pivot, i32 %evl) + ;; lanes of %r at positions >= %evl are undef + + ;; except for those lanes, %r is equivalent to %also.r + %tmp = insertelement <4 x i32> undef, i32 %pivot, i32 0 + %pivot.splat = shufflevector <4 x i32> %tmp, <4 x i32> undef, <4 x i32> zeroinitializer + %pivot.mask = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %pivot.splat + %also.r = select <4 x i1> %pivot.mask, <4 x i32> %a, <4 x i32> %b + + + +.. _int_vp_select: + +'``llvm.vp.select.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.select.v16i32 (<16 x i1> , <16 x i32> , <16 x i32> , i32 ) + declare <256 x double> @llvm.vp.select.v256f64 (<256 x i1> , <256 x double> , <256 x double> , i32 ) + +Overview: +""""""""" + +Conditional select with an explicit vector length. + + +Arguments: +"""""""""" + +The first three operands and the result are vector types of the same length. The second and third operands and the result have the same vector type. The fourth operand is the explicit vector length. + +Semantics: +"""""""""" + +The '``llvm.vp.select``' intrinsic performs conditional select (:ref:`select <i_select>`) of the second and third vector operands on each enabled lane. +If the explicit vector length (the fourth operand) is effective, the result is undefined on lanes at positions greater than or equal to the explicit vector length. + +Examples: +""""""""" + +.. 
code-block:: llvm + + %r = call <4 x i32> @llvm.vp.select.v4i32(<4 x i1> %mask, <4 x i32> %onTrue, <4 x i32> %onFalse, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %also.r = select <4 x i1> %mask, <4 x i32> %onTrue, <4 x i32> %onFalse + + + .. _int_mload_mstore: Masked Vector Load and Store Intrinsics diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -53,6 +53,10 @@ class Value; class MDNode; class BinaryOperator; +class VPIntrinsic; +namespace PatternMatch { + struct PredicatedContext; +} /// InstrInfoQuery provides an interface to query additional information for /// instructions like metadata or keywords like nsw, which provides conservative @@ -138,6 +142,13 @@ Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); +/// Given operands for an FSub, fold the result or return null. +Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q); +Value *SimplifyPredicatedFSubInst(Value *LHS, Value *RHS, + FastMathFlags FMF, const SimplifyQuery &Q, + PatternMatch::PredicatedContext & PC); + /// Given operands for an FMul, fold the result or return null. Value *SimplifyFMulInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); @@ -263,6 +274,9 @@ /// If not, this returns null. Value *SimplifyFreezeInst(Value *Op, const SimplifyQuery &Q); +/// Given a VP intrinsic function, fold the result or return null. +Value *SimplifyVPIntrinsic(VPIntrinsic & VPInst, const SimplifyQuery &Q); + /// See if we can compute a simplified version of this instruction. If not, /// return null. 
Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q, diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -45,6 +45,7 @@ class Function; class GlobalValue; class IntrinsicInst; +class PredicatedInstruction; class LoadInst; class LoopAccessInfo; class Loop; @@ -1156,6 +1157,15 @@ bool useReductionIntrinsic(unsigned Opcode, Type *Ty, ReductionFlags Flags) const; + /// \returns True if the vector length parameter should be folded into the + /// vector mask. + bool + shouldFoldVectorLengthIntoMask(const PredicatedInstruction &PredInst) const; + + /// \returns False if this VP op should be replaced by a non-VP op or an + /// unpredicated op plus a select. + bool supportsVPOperation(const PredicatedInstruction &PredInst) const; + /// \returns True if the target wants to expand the given reduction intrinsic /// into a shuffle sequence. bool shouldExpandReduction(const IntrinsicInst *II) const; @@ -1164,6 +1174,15 @@ /// to a stack reload. unsigned getGISelRematGlobalCost() const; + /// \name Vector Predication Information + /// @{ + /// Whether the target supports the %evl parameter of VP intrinsic efficiently in hardware. + /// (see LLVM Language Reference - "Vector Predication Intrinsics") + /// Use of %evl is discouraged when that is not the case. 
+ bool hasActiveVectorLength() const; + + /// @} + /// @} private: @@ -1409,10 +1428,15 @@ virtual unsigned getStoreVectorFactor(unsigned VF, unsigned StoreSize, unsigned ChainSizeInBytes, VectorType *VecTy) const = 0; + virtual bool shouldFoldVectorLengthIntoMask( + const PredicatedInstruction &PredInst) const = 0; + virtual bool + supportsVPOperation(const PredicatedInstruction &PredInst) const = 0; virtual bool useReductionIntrinsic(unsigned Opcode, Type *Ty, ReductionFlags) const = 0; virtual bool shouldExpandReduction(const IntrinsicInst *II) const = 0; virtual unsigned getGISelRematGlobalCost() const = 0; + virtual bool hasActiveVectorLength() const = 0; virtual int getInstructionLatency(const Instruction *I) = 0; }; @@ -1888,6 +1912,14 @@ VectorType *VecTy) const override { return Impl.getStoreVectorFactor(VF, StoreSize, ChainSizeInBytes, VecTy); } + bool shouldFoldVectorLengthIntoMask( + const PredicatedInstruction &PredInst) const override { + return Impl.shouldFoldVectorLengthIntoMask(PredInst); + } + bool + supportsVPOperation(const PredicatedInstruction &PredInst) const override { + return Impl.supportsVPOperation(PredInst); + } bool useReductionIntrinsic(unsigned Opcode, Type *Ty, ReductionFlags Flags) const override { return Impl.useReductionIntrinsic(Opcode, Ty, Flags); @@ -1900,6 +1932,10 @@ return Impl.getGISelRematGlobalCost(); } + bool hasActiveVectorLength() const override { + return Impl.hasActiveVectorLength(); + } + int getInstructionLatency(const Instruction *I) override { return Impl.getInstructionLatency(I); } diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -612,6 +612,15 @@ return VF; } + bool + shouldFoldVectorLengthIntoMask(const PredicatedInstruction &PredInst) const { + return true; + } + + bool supportsVPOperation(const PredicatedInstruction 
&PredInst) const { + return false; + } + bool useReductionIntrinsic(unsigned Opcode, Type *Ty, TTI::ReductionFlags Flags) const { return false; @@ -625,6 +634,10 @@ return 1; } + bool hasActiveVectorLength() const { + return false; + } + protected: // Obtain the minimum required size to hold the value (without the sign) // In case of a vector it returns the min required size for one element. diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -633,6 +633,9 @@ ATTR_KIND_NOFREE = 62, ATTR_KIND_NOSYNC = 63, ATTR_KIND_SANITIZE_MEMTAG = 64, + ATTR_KIND_MASK = 65, + ATTR_KIND_VECTORLENGTH = 66, + ATTR_KIND_PASSTHRU = 67, }; enum ComdatSelectionKindCodes { diff --git a/llvm/include/llvm/CodeGen/ExpandVectorPredication.h b/llvm/include/llvm/CodeGen/ExpandVectorPredication.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/CodeGen/ExpandVectorPredication.h @@ -0,0 +1,23 @@ +//===----- ExpandVectorPredication.h - Expand vector predication --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CODEGEN_EXPANDVECTORPREDICATION_H +#define LLVM_CODEGEN_EXPANDVECTORPREDICATION_H + +#include "llvm/IR/PassManager.h" + +namespace llvm { + +class ExpandVectorPredicationPass + : public PassInfoMixin { +public: + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); +}; +} // end namespace llvm + +#endif // LLVM_CODEGEN_EXPANDVECTORPREDICATION_H diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h --- a/llvm/include/llvm/CodeGen/ISDOpcodes.h +++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -200,6 +200,7 @@ /// Simple integer binary arithmetic operators. ADD, SUB, MUL, SDIV, UDIV, SREM, UREM, + VP_ADD, VP_SUB, VP_MUL, VP_SDIV, VP_UDIV, VP_SREM, VP_UREM, /// SMUL_LOHI/UMUL_LOHI - Multiply two integers of type iN, producing /// a signed/unsigned value of type i[2*N], and return the full value as @@ -298,6 +299,7 @@ /// Simple binary floating point operators. FADD, FSUB, FMUL, FDIV, FREM, + VP_FADD, VP_FSUB, VP_FMUL, VP_FDIV, VP_FREM, /// Constrained versions of the binary floating point operators. /// These will be lowered to the simple operators before final selection. @@ -359,6 +361,7 @@ /// FMA - Perform a * b + c with no intermediate rounding step. FMA, + VP_FMA, /// FMAD - Perform a * b + c, while getting the same result as the /// separately rounded operations. @@ -425,6 +428,19 @@ /// in terms of the element size of VEC1/VEC2, not in terms of bytes. VECTOR_SHUFFLE, + /// VP_VSHIFT(VEC1, AMOUNT, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. AMOUNT is an integer value. The returned vector is equivalent + /// to VEC1 shifted by AMOUNT (RETURNED_VEC[idx] = VEC1[idx + AMOUNT]). + VP_VSHIFT, + + /// VP_COMPRESS(VEC1, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. 
+ VP_COMPRESS, + + /// VP_EXPAND(VEC1, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. + VP_EXPAND, + /// SCALAR_TO_VECTOR(VAL) - This represents the operation of loading a /// scalar value into element 0 of the resultant vector type. The top /// elements 1 to N-1 of the N-element vector are undefined. The type @@ -451,6 +467,7 @@ /// Bitwise operators - logical and, logical or, logical xor. AND, OR, XOR, + VP_AND, VP_OR, VP_XOR, /// ABS - Determine the unsigned absolute value of a signed integer value of /// the same bitwidth. @@ -474,6 +491,7 @@ /// fshl(X,Y,Z): (X << (Z % BW)) | (Y >> (BW - (Z % BW))) /// fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW)) SHL, SRA, SRL, ROTL, ROTR, FSHL, FSHR, + VP_SHL, VP_SRA, VP_SRL, /// Byte Swap and Counting operators. BSWAP, CTTZ, CTLZ, CTPOP, BITREVERSE, @@ -493,6 +511,14 @@ /// change the condition type in order to match the VSELECT node using a /// pattern. The condition follows the BooleanContent format of the target. VSELECT, + VP_SELECT, + + /// Select with an integer pivot (op #0) and two vector operands (ops #1 + /// and #2), returning a vector result. Op #3 is the vector length, all + /// vectors have the same length. + /// Vector element below the pivot (op #0) are taken from op #1, elements + /// at positions greater-equal than the pivot are taken from op #2. + VP_COMPOSE, /// Select with condition operator - This selects between a true value and /// a false value (ops #2 and #3) based on the boolean result of comparing @@ -507,6 +533,7 @@ /// them with (op #2) as a CondCodeSDNode. If the operands are vector types /// then the result type must also be a vector type. SETCC, + VP_SETCC, /// Like SetCC, ops #0 and #1 are the LHS and RHS operands to compare, but /// op #2 is a boolean indicating if there is an incoming carry. This @@ -543,6 +570,8 @@ /// depends on the first letter) to floating point. 
SINT_TO_FP, UINT_TO_FP, + VP_SINT_TO_FP, + VP_UINT_TO_FP, /// SIGN_EXTEND_INREG - This operator atomically performs a SHL/SRA pair to /// sign extend a small value in a large integer register (e.g. sign @@ -589,6 +618,8 @@ /// the FP value cannot fit in the integer type, the results are undefined. FP_TO_SINT, FP_TO_UINT, + VP_FP_TO_SINT, + VP_FP_TO_UINT, /// X = FP_ROUND(Y, TRUNC) - Rounding 'Y' from a larger floating point type /// down to the precision of the destination VT. TRUNC is a flag, which is @@ -614,6 +645,7 @@ /// X = FP_EXTEND(Y) - Extend a smaller FP type into a larger FP type. FP_EXTEND, + VP_FP_EXTEND, /// BITCAST - This operator converts between integer, vector and FP /// values, as if the value was stored to memory with one type and loaded @@ -649,6 +681,10 @@ FCEIL, FTRUNC, FRINT, FNEARBYINT, FROUND, FFLOOR, LROUND, LLROUND, LRINT, LLRINT, + VP_FNEG, VP_FABS, VP_FSQRT, VP_FCBRT, VP_FSIN, VP_FCOS, VP_FPOWI, VP_FPOW, + VP_FLOG, VP_FLOG2, VP_FLOG10, VP_FEXP, VP_FEXP2, + VP_FCEIL, VP_FTRUNC, VP_FRINT, VP_FNEARBYINT, VP_FROUND, VP_FFLOOR, + VP_LROUND, VP_LLROUND, VP_LRINT, VP_LLRINT, /// FMINNUM/FMAXNUM - Perform floating-point minimum or maximum on two /// values. // @@ -657,6 +693,7 @@ /// /// The return value of (FMINNUM 0.0, -0.0) could be either 0.0 or -0.0. FMINNUM, FMAXNUM, + VP_FMINNUM, VP_FMAXNUM, /// FMINNUM_IEEE/FMAXNUM_IEEE - Perform floating-point minimum or maximum on /// two values, following the IEEE-754 2008 definition. This differs from @@ -894,6 +931,7 @@ // Val, OutChain = MLOAD(BasePtr, Mask, PassThru) // OutChain = MSTORE(Value, BasePtr, Mask) MLOAD, MSTORE, + VP_LOAD, VP_STORE, // Masked gather and scatter - load and store operations for a vector of // random addresses with additional mask operand that prevents memory @@ -905,6 +943,17 @@ // The Index operand can have more vector elements than the other operands // due to type legalization. The extra elements are ignored. 
MGATHER, MSCATTER, + + // VP gather and scatter - load and store operations for a vector of + // random addresses with additional mask and vector length operands that + // prevent memory accesses to the masked-off lanes. + // + // Val, OutChain = VP_GATHER(InChain, BasePtr, Index, Scale, Mask, EVL) + // OutChain = VP_SCATTER(InChain, Value, BasePtr, Index, Scale, Mask, EVL) + // + // The Index operand can have more vector elements than the other operands + // due to type legalization. The extra elements are ignored. + VP_GATHER, VP_SCATTER, /// This corresponds to the llvm.lifetime.* intrinsics. The first operand /// is the chain and the second operand is the alloca pointer. @@ -947,6 +996,14 @@ VECREDUCE_AND, VECREDUCE_OR, VECREDUCE_XOR, VECREDUCE_SMAX, VECREDUCE_SMIN, VECREDUCE_UMAX, VECREDUCE_UMIN, + VP_REDUCE_FADD, VP_REDUCE_FMUL, + VP_REDUCE_ADD, VP_REDUCE_MUL, + VP_REDUCE_AND, VP_REDUCE_OR, VP_REDUCE_XOR, + VP_REDUCE_SMAX, VP_REDUCE_SMIN, VP_REDUCE_UMAX, VP_REDUCE_UMIN, + + /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. + VP_REDUCE_FMAX, VP_REDUCE_FMIN, + /// BUILTIN_OP_END - This must be the last enum value in this list. /// The target-specific pre-isel opcode values start here. BUILTIN_OP_END @@ -1138,6 +1195,20 @@ /// SETCC_INVALID if it is not possible to represent the resultant comparison. CondCode getSetCCAndOperation(CondCode Op1, CondCode Op2, EVT Type); + /// Return the position of the mask operand for this VP opcode, if any. + /// Otherwise, return -1. + int GetMaskPosVP(unsigned OpCode); + + /// Return the position of the vector length operand for this VP opcode, if any. + /// Otherwise, return -1. + int GetVectorLengthPosVP(unsigned OpCode); + + /// Translate this VP OpCode to an unpredicated instruction OpCode.
+ unsigned GetFunctionOpCodeForVP(unsigned VPOpCode, bool hasFPExcept); + + /// Translate this non-VP Opcode to its corresponding VP Opcode + unsigned GetVPForFunctionOpCode(unsigned OpCode); + } // end llvm::ISD namespace } // end llvm namespace diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h --- a/llvm/include/llvm/CodeGen/Passes.h +++ b/llvm/include/llvm/CodeGen/Passes.h @@ -444,6 +444,11 @@ /// shuffles. FunctionPass *createExpandReductionsPass(); + /// This pass expands the vector predication intrinsics into unpredicated + /// instructions with selects or just the explicit vector length into the + /// predicate mask. + FunctionPass *createExpandVectorPredicationPass(); + // This pass expands memcmp() to load/stores. FunctionPass *createExpandMemCmpPass(); diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1049,6 +1049,20 @@ return getNode(ISD::SETCC, DL, VT, LHS, RHS, getCondCode(Cond)); } + /// Helper function to make it easier to build VP_SetCC's if you just have an + /// ISD::CondCode instead of an SDValue. + SDValue getVPSetCC(const SDLoc &DL, EVT VT, SDValue LHS, SDValue RHS, + ISD::CondCode Cond, SDValue Mask, SDValue EVL) { + assert(LHS.getValueType().isVector() == RHS.getValueType().isVector() && + "Cannot compare scalars to vectors"); + assert(LHS.getValueType().isVector() == VT.isVector() && + "Cannot compare scalars to vectors"); + assert(Cond != ISD::SETCC_INVALID && + "Cannot create a setCC of an invalid node."); + return getNode(ISD::VP_SETCC, DL, VT, LHS, RHS, getCondCode(Cond), Mask, + EVL); + } + /// Helper function to make it easier to build Select's if you just have /// operands and don't want to check for vector. 
SDValue getSelect(const SDLoc &DL, EVT VT, SDValue Cond, SDValue LHS, @@ -1190,6 +1204,20 @@ SDValue getIndexedStore(SDValue OrigStore, const SDLoc &dl, SDValue Base, SDValue Offset, ISD::MemIndexedMode AM); + /// Getters for the vector-predicated (VP) load, store, gather and scatter nodes. + SDValue getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, + SDValue Mask, SDValue VLen, EVT MemVT, + MachineMemOperand *MMO, ISD::LoadExtType); + SDValue getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val, + SDValue Ptr, SDValue Mask, SDValue VLen, EVT MemVT, + MachineMemOperand *MMO, bool IsTruncating = false); + SDValue getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, MachineMemOperand *MMO, + ISD::MemIndexType IndexType); + SDValue getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, MachineMemOperand *MMO, + ISD::MemIndexType IndexType); + SDValue getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Base, SDValue Offset, SDValue Mask, SDValue Src0, EVT MemVT, MachineMemOperand *MMO, ISD::MemIndexedMode AM, diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -550,6 +550,7 @@ friend class LSBaseSDNode; friend class MaskedLoadStoreSDNode; friend class MaskedGatherScatterSDNode; + friend class VPGatherScatterSDNode; uint16_t : NumMemSDNodeBits; @@ -565,6 +566,7 @@ class LoadSDNodeBitfields { friend class LoadSDNode; friend class MaskedLoadSDNode; + friend class VPLoadSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -575,6 +577,7 @@ class StoreSDNodeBitfields { friend class StoreSDNode; friend class MaskedStoreSDNode; + friend class VPStoreSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -705,6 +708,66 @@ } } + /// Test whether this is a vector predicated node.
+ bool isVP() const { + switch (NodeType) { + default: + return false; + case ISD::VP_LOAD: + case ISD::VP_STORE: + case ISD::VP_GATHER: + case ISD::VP_SCATTER: + + case ISD::VP_FNEG: + + case ISD::VP_FADD: + case ISD::VP_FMUL: + case ISD::VP_FSUB: + case ISD::VP_FDIV: + case ISD::VP_FREM: + + case ISD::VP_FMA: + + case ISD::VP_ADD: + case ISD::VP_MUL: + case ISD::VP_SUB: + case ISD::VP_SRA: + case ISD::VP_SRL: + case ISD::VP_SHL: + case ISD::VP_UDIV: + case ISD::VP_SDIV: + case ISD::VP_UREM: + case ISD::VP_SREM: + + case ISD::VP_EXPAND: + case ISD::VP_COMPRESS: + case ISD::VP_VSHIFT: + case ISD::VP_SETCC: + case ISD::VP_COMPOSE: + + case ISD::VP_AND: + case ISD::VP_XOR: + case ISD::VP_OR: + + case ISD::VP_REDUCE_ADD: + case ISD::VP_REDUCE_SMIN: + case ISD::VP_REDUCE_SMAX: + case ISD::VP_REDUCE_UMIN: + case ISD::VP_REDUCE_UMAX: + + case ISD::VP_REDUCE_MUL: + case ISD::VP_REDUCE_AND: + case ISD::VP_REDUCE_OR: + case ISD::VP_REDUCE_FADD: + case ISD::VP_REDUCE_FMUL: + case ISD::VP_REDUCE_FMIN: + case ISD::VP_REDUCE_FMAX: + + return true; + } + } + + /// Test if this node has a post-isel opcode, directly /// corresponding to a MachineInstr opcode. 
bool isMachineOpcode() const { return NodeType < 0; } @@ -1416,6 +1479,10 @@ N->getOpcode() == ISD::MSTORE || N->getOpcode() == ISD::MGATHER || N->getOpcode() == ISD::MSCATTER || + N->getOpcode() == ISD::VP_LOAD || + N->getOpcode() == ISD::VP_STORE || + N->getOpcode() == ISD::VP_GATHER || + N->getOpcode() == ISD::VP_SCATTER || N->isMemIntrinsic() || N->isTargetMemoryOpcode(); } @@ -2277,6 +2344,96 @@ } }; +/// This base class is used to represent VP_LOAD and VP_STORE nodes +class VPLoadStoreSDNode : public MemSDNode { +public: + friend class SelectionDAG; + + VPLoadStoreSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {} + + // VPLoadSDNode (Chain, ptr, mask, VLen) + // VPStoreSDNode (Chain, data, ptr, mask, VLen) + // Mask is a vector of i1 elements, Vlen is i32 + const SDValue &getBasePtr() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 1 : 2); + } + const SDValue &getMask() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 2 : 3); + } + const SDValue &getVectorLength() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 
3 : 4); + } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_LOAD || + N->getOpcode() == ISD::VP_STORE; + } +}; + +/// This class is used to represent a VP_LOAD node +class VPLoadSDNode : public VPLoadStoreSDNode { +public: + friend class SelectionDAG; + + VPLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + ISD::LoadExtType ETy, EVT MemVT, + MachineMemOperand *MMO) + : VPLoadStoreSDNode(ISD::VP_LOAD, Order, dl, VTs, MemVT, MMO) { + LoadSDNodeBits.ExtTy = ETy; + LoadSDNodeBits.IsExpanding = false; + } + + ISD::LoadExtType getExtensionType() const { + return static_cast(LoadSDNodeBits.ExtTy); + } + + const SDValue &getBasePtr() const { return getOperand(1); } + const SDValue &getMask() const { return getOperand(2); } + const SDValue &getVectorLength() const { return getOperand(3); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_LOAD; + } + bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; } +}; + +/// This class is used to represent a VP_STORE node +class VPStoreSDNode : public VPLoadStoreSDNode { +public: + friend class SelectionDAG; + + VPStoreSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + bool isTrunc, EVT MemVT, + MachineMemOperand *MMO) + : VPLoadStoreSDNode(ISD::VP_STORE, Order, dl, VTs, MemVT, MMO) { + StoreSDNodeBits.IsTruncating = isTrunc; + StoreSDNodeBits.IsCompressing = false; + } + + /// Return true if this is a truncating store. + /// For integers this is the same as doing a TRUNCATE and storing the result. + /// For floats, it is the same as doing an FP_ROUND and storing the result. + bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; } + + /// Returns true if the op does a compression to the vector before storing. + /// The node contiguously stores the active elements (integers or floats) + /// in src (those with their respective bit set in writemask k) to unaligned + /// memory at base_addr. 
+ bool isCompressingStore() const { return StoreSDNodeBits.IsCompressing; } + + const SDValue &getValue() const { return getOperand(1); } + const SDValue &getBasePtr() const { return getOperand(2); } + const SDValue &getMask() const { return getOperand(3); } + const SDValue &getVectorLength() const { return getOperand(4); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_STORE; + } +}; + /// This base class is used to represent MLOAD and MSTORE nodes class MaskedLoadStoreSDNode : public MemSDNode { public: @@ -2385,6 +2542,85 @@ } }; +/// This is a base class used to represent +/// VP_GATHER and VP_SCATTER nodes +/// +class VPGatherScatterSDNode : public MemSDNode { +public: + friend class SelectionDAG; + + VPGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO, ISD::MemIndexType IndexType) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) { + LSBaseSDNodeBits.AddressingMode = IndexType; + assert(getIndexType() == IndexType && "Value truncated"); + } + + /// How is Index applied to BasePtr when computing addresses. + ISD::MemIndexType getIndexType() const { + return static_cast(LSBaseSDNodeBits.AddressingMode); + } + bool isIndexScaled() const { + return (getIndexType() == ISD::SIGNED_SCALED) || + (getIndexType() == ISD::UNSIGNED_SCALED); + } + bool isIndexSigned() const { + return (getIndexType() == ISD::SIGNED_SCALED) || + (getIndexType() == ISD::SIGNED_UNSCALED); + } + + // Operand order of the two nodes (scatter operands are shifted by one): + // VPGatherSDNode (Chain, base, index, scale, mask, vlen) + // VPScatterSDNode (Chain, value, base, index, scale, mask, vlen) + // Mask is a vector of i1 elements + const SDValue &getBasePtr() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 1 : 2); } + const SDValue &getIndex() const { return getOperand((getOpcode() == ISD::VP_GATHER) ?
2 : 3); } + const SDValue &getScale() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 3 : 4); } + const SDValue &getMask() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 4 : 5); } + const SDValue &getVectorLength() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 5 : 6); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_GATHER || + N->getOpcode() == ISD::VP_SCATTER; + } +}; + +/// This class is used to represent an VP_GATHER node +/// +class VPGatherSDNode : public VPGatherScatterSDNode { +public: + friend class SelectionDAG; + + VPGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO, + ISD::MemIndexType IndexType) + : VPGatherScatterSDNode(ISD::VP_GATHER, Order, dl, VTs, MemVT, MMO, IndexType) {} + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_GATHER; + } +}; + +/// This class is used to represent an VP_SCATTER node +/// +class VPScatterSDNode : public VPGatherScatterSDNode { +public: + friend class SelectionDAG; + + VPScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO, + ISD::MemIndexType IndexType) + : VPGatherScatterSDNode(ISD::VP_SCATTER, Order, dl, VTs, MemVT, MMO, IndexType) {} + + const SDValue &getValue() const { return getOperand(1); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_SCATTER; + } +}; + + /// This is a base class used to represent /// MGATHER and MSCATTER nodes /// diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td --- a/llvm/include/llvm/IR/Attributes.td +++ b/llvm/include/llvm/IR/Attributes.td @@ -139,6 +139,15 @@ /// Parameter is required to be a trivial constant. def ImmArg : EnumAttr<"immarg">; +/// Return value that is equal to this argument on enabled lanes (mask). +def Passthru : EnumAttr<"passthru">; + +/// Mask argument that applies to this function. 
+def Mask : EnumAttr<"mask">; + +/// Dynamic Vector Length argument of this function. +def VectorLength : EnumAttr<"vlen">; + /// Function can return twice. def ReturnsTwice : EnumAttr<"returns_twice">; diff --git a/llvm/include/llvm/IR/FPEnv.h b/llvm/include/llvm/IR/FPEnv.h --- a/llvm/include/llvm/IR/FPEnv.h +++ b/llvm/include/llvm/IR/FPEnv.h @@ -21,6 +21,9 @@ namespace llvm { +class LLVMContext; +class Value; + namespace fp { /// Rounding mode used for floating point operations. @@ -47,7 +50,7 @@ ebStrict ///< This corresponds to "fpexcept.strict". }; -} +} // namespace fp /// Returns a valid RoundingMode enumerator when given a string /// that is valid as input in constrained intrinsic rounding mode @@ -66,5 +69,11 @@ /// input in constrained intrinsic exception behavior metadata. Optional ExceptionBehaviorToStr(fp::ExceptionBehavior); -} +/// Return the IR Value representation of any ExceptionBehavior. +Value *GetConstrainedFPExcept(LLVMContext &, fp::ExceptionBehavior); + +/// Return the IR Value representation of any RoundingMode. +Value *GetConstrainedFPRounding(LLVMContext &, fp::RoundingMode); + +} // namespace llvm #endif diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h --- a/llvm/include/llvm/IR/IRBuilder.h +++ b/llvm/include/llvm/IR/IRBuilder.h @@ -29,6 +29,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstrTypes.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -855,6 +856,23 @@ /// assume that the provided condition will be true. CallInst *CreateAssumption(Value *Cond); + /// Call an arithmetic VP intrinsic. + Instruction *CreateVectorPredicatedInst(unsigned OC, ArrayRef, + Instruction *FMFSource = nullptr, + const Twine &Name = ""); + + /// Call a comparison VP intrinsic.
+ Instruction *CreateVectorPredicatedCmp(CmpInst::Predicate Pred, + Value *FirstOp, Value *SndOp, Value *Mask, + Value *VectorLength, + const Twine &Name = ""); + + /// Call a reduction VP intrinsic. + Instruction *CreateVectorPredicatedReduce(Module &M, CmpInst::Predicate Pred, + Value *FirstOp, Value *SndOp, Value *Mask, + Value *VectorLength, + const Twine &Name = ""); + /// Create a call to the experimental.gc.statepoint intrinsic to /// start a new statepoint sequence. CallInst *CreateGCStatepointCall(uint64_t ID, uint32_t NumPatchBytes, @@ -1236,11 +1254,7 @@ if (Except.hasValue()) UseExcept = Except.getValue(); - Optional ExceptStr = ExceptionBehaviorToStr(UseExcept); - assert(ExceptStr.hasValue() && "Garbage strict exception behavior!"); - auto *ExceptMDS = MDString::get(Context, ExceptStr.getValue()); - - return MetadataAsValue::get(Context, ExceptMDS); + return GetConstrainedFPExcept(Context, UseExcept); } Value *getConstrainedFPPredicate(CmpInst::Predicate Predicate) { diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -206,9 +206,175 @@ /// @} }; + /// This is the common base class for vector predication intrinsics. + class VPIntrinsic : public IntrinsicInst { + public: + enum class VPTypeToken : int8_t { + Returned = 0, // vectorized return type. + Vector = 1, // vector operand type + Pointer = 2, // vector pointer-operand type (memory op) + Mask = 3 // vector mask type + }; + + using TypeTokenVec = SmallVector; + using ShortTypeVec = SmallVector; + + /// \brief Declares a llvm.vp.* intrinsic in \p M that matches the parameters \p Params. + static Function* GetDeclarationForParams(Module *M, Intrinsic::ID, ArrayRef Params, Type* VecRetTy = nullptr); + + // Type tokens required to instantiate this intrinsic.
+ static TypeTokenVec GetTypeTokens(Intrinsic::ID); + + // whether the intrinsic has a rounding mode parameter (regardless of + // setting). + static bool HasRoundingModeParam(Intrinsic::ID VPID) { return GetRoundingModeParamPos(VPID) != None; } + // whether the intrinsic has an exception behavior parameter (regardless of + // setting). + static bool HasExceptionBehaviorParam(Intrinsic::ID VPID) { return GetExceptionBehaviorParamPos(VPID) != None; } + static Optional GetMaskParamPos(Intrinsic::ID IntrinsicID); + static Optional GetVectorLengthParamPos(Intrinsic::ID IntrinsicID); + static Optional + GetExceptionBehaviorParamPos(Intrinsic::ID IntrinsicID); + static Optional GetRoundingModeParamPos(Intrinsic::ID IntrinsicID); + // the llvm.vp.* intrinsic for this other kind of intrinsic. + static Intrinsic::ID GetForIntrinsic(Intrinsic::ID IntrinsicID); + static Intrinsic::ID GetForOpcode(unsigned OC); + + // Whether \p ID is a VP intrinsic ID. + static bool IsVPIntrinsic(Intrinsic::ID); + + /// TODO make this private! + /// \brief Generate the disambiguating type vec for this VP Intrinsic. + /// \returns A disambiguating type vector to instantiate this intrinsic. + /// \p TTVec + /// Vector of disambiguating tokens. + /// \p VecRetTy + /// The return type of the intrinsic (optional) + /// \p VecPtrTy + /// The pointer operand type (optional) + /// \p VectorTy + /// The vector data type of the operation. + static VPIntrinsic::ShortTypeVec + EncodeTypeTokens(VPIntrinsic::TypeTokenVec TTVec, Type *VecRetTy, + Type *VecPtrTy, Type &VectorTy); + + /// set the mask parameter. + /// this asserts if the underlying intrinsic has no mask parameter. + void setMaskParam(Value *); + + /// set the vector length parameter. + /// this asserts if the underlying intrinsic has no vector length + /// parameter. + void setVectorLengthParam(Value *); + + /// \return the mask parameter or nullptr. + Value *getMaskParam() const; + + /// \return the vector length parameter or nullptr.
+ Value *getVectorLengthParam() const; + + /// \return whether the vector length param can be ignored. + bool canIgnoreVectorLengthParam() const; + + /// \return the alignment of the pointer used by this load/store/gather or scatter. + MaybeAlign getPointerAlignment() const; + // MaybeAlign setPointerAlignment(Align NewAlign); // TODO + + /// \return The pointer operand of this load, store, gather or scatter. + Value *getMemoryPointerParam() const; + static Optional GetMemoryPointerParamPos(Intrinsic::ID); + + /// \return The data (payload) operand of this store or scatter. + Value *getMemoryDataParam() const; + static Optional GetMemoryDataParamPos(Intrinsic::ID); + + /// \return The vector to reduce if this is a reduction operation. + Value *getReductionVectorParam() const; + static Optional GetReductionVectorParamPos(Intrinsic::ID VPID); + + /// \return The initial value if this is a reduction operation. + Value *getReductionAccuParam() const; + static Optional GetReductionAccuParamPos(Intrinsic::ID VPID); + + /// \return the static element count (vector number of elements) the vector + /// length parameter applies to. + ElementCount getStaticVectorLength() const; + + bool isUnaryOp() const; + static bool IsUnaryVPOp(Intrinsic::ID); + bool isBinaryOp() const; + static bool IsBinaryVPOp(Intrinsic::ID); + bool isTernaryOp() const; + static bool IsTernaryVPOp(Intrinsic::ID); + + /// \returns Whether this is a comparison operation. + bool isCompareOp() const; + static bool IsCompareVPOp(Intrinsic::ID); + + /// \returns The comparison predicate. + CmpInst::Predicate getCmpPredicate() const; + + // Constrained fp-math + // whether this is an fp op with non-standard rounding or exception + // behavior. + bool isConstrainedOp() const; + + // the specified rounding mode. + Optional getRoundingMode() const; + // the specified exception behavior.
+ Optional getExceptionBehavior() const; + + // llvm.vp.reduction.* + bool isReductionOp() const; + static bool IsVPReduction(Intrinsic::ID VPIntrin); + + // Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + return IsVPIntrinsic(I->getIntrinsicID()); + } + static bool classof(const Value *V) { + return isa(V) && classof(cast(V)); + } + + /// \return The non-VP intrinsic that is functionally equivalent to this VP + /// intrinsic. + Intrinsic::ID getFunctionalIntrinsicID() const { + Intrinsic::ID IID = Intrinsic::not_intrinsic; + // Return a constrained intrinsic if this intrinsic does not operate in + // the standard fp environment. + if (isConstrainedOp()) { + IID = GetConstrainedIntrinsicForVP(getIntrinsicID()); + } + if (IID == Intrinsic::not_intrinsic) { + IID = GetFunctionalIntrinsicForVP(getIntrinsicID()); + } + return IID; + } + + /// \return The llvm.experimental.constrained.* intrinsic that is + /// functionally equivalent to this llvm.vp.* intrinsic. + static Intrinsic::ID GetConstrainedIntrinsicForVP(Intrinsic::ID VPID); + + /// \return The intrinsic that is + /// functionally equivalent to this llvm.vp.* intrinsic. + static Intrinsic::ID GetFunctionalIntrinsicForVP(Intrinsic::ID VPID); + + // Equivalent non-predicated opcode + unsigned getFunctionalOpcode() const { + if (isConstrainedOp()) { + return Instruction::Call; + } + return GetFunctionalOpcodeForVP(getIntrinsicID()); + } + + // Equivalent non-predicated opcode + static unsigned GetFunctionalOpcodeForVP(Intrinsic::ID ID); + }; + /// This is the common base class for constrained floating point intrinsics. 
class ConstrainedFPIntrinsic : public IntrinsicInst { public: + bool isUnaryOp() const; bool isTernaryOp() const; Optional getRoundingMode() const; diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -27,6 +27,10 @@ // effects. It may be CSE'd deleted if dead, etc. def IntrNoMem : IntrinsicProperty; +// IntrNoSync - Threads executing the intrinsic will not synchronize using +// memory or other means. +def IntrNoSync : IntrinsicProperty; + // IntrReadMem - This intrinsic only reads from memory. It does not write to // memory and has no other side effects. Therefore, it cannot be moved across // potentially aliasing stores. However, it can be reordered otherwise and can @@ -98,6 +102,25 @@ int ArgNo = argNo; } +// VectorLength - The specified argument is the Dynamic Vector Length of the +// operation. +class VectorLength : IntrinsicProperty { + int ArgNo = argNo; +} + +// Mask - The specified argument contains the per-lane mask of this +// intrinsic. Inputs on masked-out lanes must not affect the result of this +// intrinsic (except for the Passthru argument). +class Mask : IntrinsicProperty { + int ArgNo = argNo; +} +// Passthru - The specified argument contains the per-lane return value +// for this vector intrinsic where the mask is false. 
+// (requires the Mask attribute in the same function) +class Passthru : IntrinsicProperty { + int ArgNo = argNo; +} + def IntrNoReturn : IntrinsicProperty; def IntrWillReturn : IntrinsicProperty; @@ -1153,8 +1176,472 @@ def int_ptrmask: Intrinsic<[llvm_anyptr_ty], [llvm_anyptr_ty, llvm_anyint_ty], [IntrNoMem, IntrSpeculatable, IntrWillReturn]>; +//===---------------- Vector Predication Intrinsics --------------===// + +// Memory Intrinsics +def int_vp_store : Intrinsic<[], + [ llvm_anyvector_ty, + LLVMAnyPointerType>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ NoCapture<1>, IntrNoSync, IntrArgMemOnly, IntrWillReturn, Mask<2>, VectorLength<3> ]>; + +def int_vp_load : Intrinsic<[ llvm_anyvector_ty], + [ LLVMAnyPointerType>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ NoCapture<0>, IntrNoSync, IntrReadMem, IntrWillReturn, IntrArgMemOnly, Mask<1>, VectorLength<2> ]>; + +def int_vp_gather: Intrinsic<[ llvm_anyvector_ty], + [ LLVMVectorOfAnyPointersToElt<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrReadMem, IntrNoSync, IntrWillReturn, IntrArgMemOnly, Mask<1>, VectorLength<2> ]>; + +def int_vp_scatter: Intrinsic<[], + [ llvm_anyvector_ty, + LLVMVectorOfAnyPointersToElt<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrArgMemOnly, IntrNoSync, IntrWillReturn, Mask<2>, VectorLength<3> ]>; +// TODO allow IntrNoCapture for vectors of pointers + +// Reductions +let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn, Mask<1>, VectorLength<2>] in { + def int_vp_reduce_add : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_mul : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_and : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + 
llvm_i32_ty]>; + def int_vp_reduce_or : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_xor : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_smax : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_smin : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_umax : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_umin : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_fmax : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_fmin : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + +let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn, Mask<2>, VectorLength<3>] in { + def int_vp_reduce_fadd : Intrinsic<[LLVMVectorElementType<0>], + [LLVMVectorElementType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_fmul : Intrinsic<[LLVMVectorElementType<0>], + [LLVMVectorElementType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + +// Binary operators +let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in { + def int_vp_add : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + 
LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_mul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_udiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_srem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_urem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +// Logical operators + def int_vp_ashr : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_lshr : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_shl : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_or : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_and : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_xor : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +} + +// Comparison +// TODO add signalling fcmp +// The last argument is the comparison predicate +def int_vp_icmp : Intrinsic<[ LLVMScalarOrSameVectorWidth<0, 
llvm_i1_ty> ], + [ llvm_anyvector_ty, + LLVMMatchType<0>, + llvm_i8_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ], + [ IntrWillReturn, IntrNoSync, IntrNoMem, Mask<3>, VectorLength<4>, ImmArg<2> ]>; + +def int_vp_fcmp : Intrinsic<[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ], + [ llvm_anyvector_ty, + LLVMMatchType<0>, + llvm_i8_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ], + [ IntrWillReturn, IntrNoSync, IntrNoMem, Mask<3>, VectorLength<4>, ImmArg<2> ]>; + + + +// Shuffle +def int_vp_vshift: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_i32_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrNoSync, IntrWillReturn, Mask<2>, VectorLength<3> ]>; + +def int_vp_expand: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrNoSync, IntrWillReturn, Mask<1>, VectorLength<2> ]>; + +def int_vp_compress: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrNoSync, IntrWillReturn, VectorLength<2> ]>; + +// Select +def int_vp_select : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty], + [ IntrNoMem, IntrNoSync, IntrWillReturn, Passthru<2>, Mask<0>, VectorLength<3> ]>; + +// Compose +def int_vp_compose : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty, + llvm_i32_ty], + [ IntrNoMem, IntrNoSync, IntrWillReturn, VectorLength<3> ]>; + + + +// VP fp rounding and truncation +let IntrProperties = [ IntrNoMem, IntrNoSync, IntrWillReturn, Mask<2>, VectorLength<3> ] in { + + def int_vp_fptosi : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_vp_fptoui : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + 
llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_vp_fpext : Intrinsic<[ llvm_anyfloat_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} +let IntrProperties = [ IntrNoMem, IntrWillReturn, Mask<3>, VectorLength<4> ] in { + def int_vp_sitofp : Intrinsic<[ llvm_anyfloat_ty ], + [ llvm_anyint_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_vp_uitofp : Intrinsic<[ llvm_anyfloat_ty ], + [ llvm_anyint_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + + +let IntrProperties = [ IntrNoMem, IntrNoSync, IntrWillReturn, Mask<3>, VectorLength<4> ] in { + def int_vp_fptrunc : Intrinsic<[ llvm_anyfloat_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +} + +// VP single argument constrained intrinsics. +let IntrProperties = [ IntrNoMem, IntrNoSync, IntrWillReturn, Mask<3>, VectorLength<4> ] in { + // These intrinsics are sensitive to the rounding mode so we need constrained + // versions of each of them. When strict rounding and exception control are + // not required the non-constrained versions of these intrinsics should be + // used. 
+ def int_vp_sqrt : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sin : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_cos : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_log : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_log10: Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_log2 : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_exp : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_exp2 : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_rint : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_nearbyint : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_ceil : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_floor : Intrinsic<[ llvm_anyfloat_ty 
], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_round : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_trunc : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + + +// VP two argument constrained intrinsics. +let IntrProperties = [ IntrNoMem, IntrNoSync, IntrWillReturn, Mask<4>, VectorLength<5> ] in { + // These intrinsics are sensitive to the rounding mode so we need constrained + // versions of each of them. When strict rounding and exception control are + // not required the non-constrained versions of these intrinsics should be + // used. + def int_vp_powi : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_i32_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_pow : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_maxnum : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_minnum : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +} + + +// VP standard fp-math intrinsics. 
+def int_vp_fneg : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, Mask<2>, VectorLength<3> ]>; + +let IntrProperties = [ IntrNoMem, IntrWillReturn, Mask<4>, VectorLength<5> ] in { + // These intrinsics are sensitive to the rounding mode so we need constrained + // versions of each of them. When strict rounding and exception control are + // not required the non-constrained versions of these intrinsics should be + // used. + def int_vp_fadd : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_fsub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_fmul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_fdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_frem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; +} + +def int_vp_fma : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ], + [ IntrNoMem, IntrNoSync, IntrWillReturn, Mask<5>, VectorLength<6> ]>; + + + + //===-------------------------- Masked Intrinsics -------------------------===// -// +// TODO poised for deprecation (to be superseded by llvm.vp.* 
intrinsics) def int_masked_store : Intrinsic<[], [llvm_anyvector_ty, LLVMAnyPointerType>, llvm_i32_ty, @@ -1255,6 +1742,7 @@ ImmArg<3> ]>; //===------------------------ Reduction Intrinsics ------------------------===// +// TODO poised for deprecation (to be superseded by llvm.vp.*. intrinsics) // let IntrProperties = [IntrNoMem, IntrWillReturn] in { def int_experimental_vector_reduce_v2_fadd : Intrinsic<[llvm_anyfloat_ty], diff --git a/llvm/include/llvm/IR/MatcherCast.h b/llvm/include/llvm/IR/MatcherCast.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/MatcherCast.h @@ -0,0 +1,65 @@ +#ifndef LLVM_IR_MATCHERCAST_H +#define LLVM_IR_MATCHERCAST_H + +//===- MatcherCast.h - Match on the LLVM IR --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Parameterized class hierachy for templatized pattern matching. +// +//===----------------------------------------------------------------------===// + + +namespace llvm { +namespace PatternMatch { + + +// type modification +template +struct MatcherCast { }; + +// whether the Value \p Obj behaves like a \p Class. 
+template +bool match_isa(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return isa(Obj); +} + +template +auto match_cast(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return cast(Obj); +} +template +auto match_dyn_cast(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return dyn_cast(Obj); +} + +template +auto match_cast(Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return cast(Obj); +} +template +auto match_dyn_cast(Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return dyn_cast(Obj); +} + + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_MATCHERCAST_H + diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -41,22 +41,81 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" +#include "llvm/IR/MatcherCast.h" + #include + namespace llvm { namespace PatternMatch { +// Use verbatim types in default (empty) context. +struct EmptyContext { + EmptyContext() {} + + EmptyContext(const Value *) {} + + EmptyContext(const EmptyContext & E) {} + + // reset this match context to be rooted at \p V + void reset(Value * V) {} + + // accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value * Val) const { return true; } + + // accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value * Val) const { return true; } + + // whether this context is compatiable with \p E. 
+ bool acceptContext(EmptyContext E) const { return true; } + + // merge the context \p E into this context and return whether the resulting context is valid. + bool mergeContext(EmptyContext E) { return true; } + + // reset this context to \p Val. + template bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + // match in the current context + template bool try_match(Val *V, const Pattern &P) { + return const_cast(P).match_context(V, *this); + } +}; + +template +struct MatcherCast { using ActualCastType = DestClass; }; + + + + + + +// match without (== empty) context template bool match(Val *V, const Pattern &P) { - return const_cast(P).match(V); + EmptyContext ECtx; + return const_cast(P).match_context(V, ECtx); +} + +// match pattern in a given context +template bool match(Val *V, const Pattern &P, MatchContext & MContext) { + return const_cast(P).match_context(V, MContext); } + + template struct OneUse_match { SubPattern_t SubPattern; OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} template bool match(OpTy *V) { - return V->hasOneUse() && SubPattern.match(V); + EmptyContext EContext; return match_context(V, EContext); + } + + template bool match_context(OpTy *V, MatchContext & MContext) { + return V->hasOneUse() && SubPattern.match_context(V, MContext); } }; @@ -65,7 +124,11 @@ } template struct class_match { - template bool match(ITy *V) { return isa(V); } + template bool match(ITy *V) { + EmptyContext EContext; return match_context(V, EContext); + } + template + bool match_context(ITy *V, MatchContext & MContext) { return match_isa(V); } }; /// Match an arbitrary value and ignore it. 
@@ -121,11 +184,17 @@ match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + MatchContext SubContext; + + if (L.match_context(V, SubContext) && MContext.acceptContext(SubContext)) { + MContext.mergeContext(SubContext); return true; - if (R.match(V)) + } + if (R.match_context(V, MContext)) { return true; + } return false; } }; @@ -136,9 +205,10 @@ match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) - if (R.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + if (L.match_context(V, MContext)) + if (R.match_context(V, MContext)) return true; return false; } @@ -163,7 +233,8 @@ apint_match(const APInt *&Res, bool AllowUndef) : Res(Res), AllowUndef(AllowUndef) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValue(); return true; @@ -188,7 +259,8 @@ apfloat_match(const APFloat *&Res, bool AllowUndef) : Res(Res), AllowUndef(AllowUndef) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValueAPF(); return true; @@ -239,7 +311,8 @@ } template struct constantint_match { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) { const 
APInt &CIV = CI->getValue(); if (Val >= 0) @@ -262,7 +335,8 @@ /// satisfy a specified predicate. /// For vector constants, undefined elements are ignored. template struct cst_pred_ty : public Predicate { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) return this->isValue(CI->getValue()); if (V->getType()->isVectorTy()) { @@ -299,7 +373,8 @@ api_pred_ty(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -321,7 +396,8 @@ /// constants that satisfy a specified predicate. /// For vector constants, undefined elements are ignored. template struct cstfp_pred_ty : public Predicate { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CF = dyn_cast(V)) return this->isValue(CF->getValueAPF()); if (V->getType()->isVectorTy()) { @@ -456,7 +532,8 @@ } struct is_zero { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { auto *C = dyn_cast(V); return C && (C->isNullValue() || cst_pred_ty().match(C)); } @@ -613,8 +690,11 @@ bind_ty(Class *&V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CV = dyn_cast(V)) { + if (!MContext.acceptBoundNode(V)) return false; + VR = CV; return true; } @@ -656,7 +736,8 @@ 
specificval_ty(const Value *V) : Val(V) {} - template bool match(ITy *V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { return V == Val; } }; /// Match if we have a specific specified value. @@ -669,7 +750,8 @@ deferredval_ty(Class *const &V) : Val(V) {} - template bool match(ITy *const V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *const V, MatchContext & MContext) { return V == Val; } }; /// A commutative-friendly version of m_Specific(). @@ -685,7 +767,8 @@ specific_fpval(double V) : Val(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CFP = dyn_cast(V)) return CFP->isExactlyValue(Val); if (V->getType()->isVectorTy()) @@ -708,7 +791,8 @@ bind_const_intval_ty(uint64_t &V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CV = dyn_cast(V)) if (CV->getValue().ule(UINT64_MAX)) { VR = CV->getZExtValue(); @@ -725,7 +809,8 @@ specific_intval(APInt V) : Val(std::move(V)) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { const auto *CI = dyn_cast(V); if (!CI && V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) @@ -755,7 +840,8 @@ specific_bbval(BasicBlock *Val) : Val(Val) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EC; return match_context(V, EC); } + template bool match_context(ITy *V, MatchContext & MContext) { const auto *BB = 
dyn_cast(V); return BB && BB == Val; } @@ -787,11 +873,16 @@ // The LHS is always matched first. AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (!I) return false; + + if (!MContext.acceptInnerNode(I)) return false; + + MatchContext LRContext(MContext); + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; return false; } }; @@ -835,12 +926,15 @@ // The LHS is always matched first. 
BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + MatchContext LRContext(MContext); + if (!MContext.acceptInnerNode(I)) return false; + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; + return false; } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opcode && @@ -879,25 +973,26 @@ Op_t X; FNeg_match(const Op_t &Op) : X(Op) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { auto *FPMO = dyn_cast(V); if (!FPMO) return false; - if (FPMO->getOpcode() == Instruction::FNeg) + if (match_cast(V)->getOpcode() == Instruction::FNeg) return X.match(FPMO->getOperand(0)); - if (FPMO->getOpcode() == Instruction::FSub) { + if (match_cast(V)->getOpcode() == Instruction::FSub) { if (FPMO->hasNoSignedZeros()) { // With 'nsz', any zero goes. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } else { // Without 'nsz', we need fsub -0.0, X exactly. 
- if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } - return X.match(FPMO->getOperand(1)); + return X.match_context(FPMO->getOperand(1), MContext); } return false; @@ -1011,7 +1106,8 @@ OverflowingBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *Op = dyn_cast(V)) { if (Op->getOpcode() != Opcode) return false; @@ -1021,7 +1117,7 @@ if (WrapFlags & OverflowingBinaryOperator::NoSignedWrap && !Op->hasNoSignedWrap()) return false; - return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); + return L.match_context(Op->getOperand(0), MContext) && R.match_context(Op->getOperand(1), MContext); } return false; } @@ -1103,10 +1199,11 @@ BinOpPred_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return this->isOpType(I->getOpcode()) && L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) + return this->isOpType(I->getOpcode()) && L.match_context(I->getOperand(0), MContext) && + R.match_context(I->getOperand(1), MContext); if (auto *CE = dyn_cast(V)) return this->isOpType(CE->getOpcode()) && L.match(CE->getOperand(0)) && R.match(CE->getOperand(1)); @@ -1198,9 +1295,10 @@ Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *PEO = dyn_cast(V)) - return PEO->isExact() && SubPattern.match(V); + 
return PEO->isExact() && SubPattern.match_context(V, MContext); return false; } }; @@ -1225,13 +1323,27 @@ CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) { + if (!MContext.acceptInnerNode(I)) + return false; + + MatchContext LRContext(MContext); + if (L.match_context(I->getOperand(0), LRContext) && + R.match_context(I->getOperand(1), LRContext) && + MContext.mergeContext(LRContext)) { Predicate = I->getPredicate(); return true; - } else if (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))) { + } + + if (!Commutable) + return false; + + MatchContext RLContext(MContext); + if (L.match_context(I->getOperand(1), RLContext) && + R.match_context(I->getOperand(0), RLContext) && + MContext.mergeContext(RLContext)) { Predicate = I->getSwappedPredicate(); return true; } @@ -1268,10 +1380,11 @@ OneOps_match(const T0 &Op1) : Op1(Op1) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext); } return false; } @@ -1284,10 +1397,12 @@ TwoOps_match(const T0 &Op1, const T1 &Op2) : Op1(Op1), Op2(Op2) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)); 
+ template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext); } return false; } @@ -1303,11 +1418,13 @@ ThreeOps_match(const T0 &Op1, const T1 &Op2, const T2 &Op3) : Op1(Op1), Op2(Op2), Op3(Op3) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Op3.match(I->getOperand(2)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext) && + Op3.match_context(I->getOperand(2), MContext); } return false; } @@ -1381,9 +1498,10 @@ CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} - template bool match(OpTy *V) { - if (auto *O = dyn_cast(V)) - return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto O = match_dyn_cast(V)) + return O->getOpcode() == Opcode && MContext.acceptInnerNode(O) && Op.match_context(O->getOperand(0), MContext); return false; } }; @@ -1485,8 +1603,9 @@ br_match(BasicBlock *&Succ) : Succ(Succ) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = 
match_dyn_cast(V)) if (BI->isUnconditional()) { Succ = BI->getSuccessor(0); return true; @@ -1506,10 +1625,12 @@ brc_match(const Cond_t &C, const TrueBlock_t &t, const FalseBlock_t &f) : Cond(C), T(t), F(f) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) - if (BI->isConditional() && Cond.match(BI->getCondition())) - return T.match(BI->getSuccessor(0)) && F.match(BI->getSuccessor(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) + if (BI->isConditional() && Cond.match(BI->getCondition())) { + return T.match_context(BI->getSuccessor(0), MContext) && F.match_context(BI->getSuccessor(1), MContext); + } return false; } }; @@ -1541,13 +1662,14 @@ // The LHS is always matched first. MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". - auto *SI = dyn_cast(V); - if (!SI) + auto *SI = match_dyn_cast(V); + if (!SI || !MContext.acceptInnerNode(SI)) return false; - auto *Cmp = dyn_cast(SI->getCondition()); - if (!Cmp) + auto *Cmp = match_dyn_cast(SI->getCondition()); + if (!Cmp || !MContext.acceptInnerNode(Cmp)) return false; // At this point we have a select conditioned on a comparison. Check that // it is the values returned by the select that are being compared. @@ -1563,9 +1685,12 @@ // Does "(x pred y) ? x : y" represent the desired max/min operation? if (!Pred_t::match(Pred)) return false; + // It does! Bind the operands. 
- return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + MatchContext LRContext(MContext); + if (L.match_context(LHS, LRContext) && R.match_context(RHS, LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(RHS, MContext) && R.match_context(LHS, MContext))) return true; + return false; } }; @@ -1723,7 +1848,8 @@ UAddWithOverflow_match(const LHS_t &L, const RHS_t &R, const Sum_t &S) : L(L), R(R), S(S) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { Value *ICmpLHS, *ICmpRHS; ICmpInst::Predicate Pred; if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) @@ -1789,9 +1915,10 @@ Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // FIXME: Should likely be switched to use `CallBase`. 
- if (const auto *CI = dyn_cast(V)) + if (const auto *CI = match_dyn_cast(V)) return Val.match(CI->getArgOperand(OpI)); return false; } @@ -1809,8 +1936,9 @@ IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} - template bool match(OpTy *V) { - if (const auto *CI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (const auto *CI = match_dyn_cast(V)) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; @@ -2035,7 +2163,8 @@ Opnd_t Val; Signum_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { unsigned TypeSize = V->getType()->getScalarSizeInBits(); if (TypeSize == 0) return false; @@ -2075,10 +2204,11 @@ Opnd_t Val; ExtractValue_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EC; return match_context(V, EC); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *I = dyn_cast(V)) return I->getNumIndices() == 1 && I->getIndices()[0] == Ind && - Val.match(I->getAggregateOperand()); + Val.match_context(I->getAggregateOperand(), MContext); return false; } }; @@ -2106,7 +2236,8 @@ const DataLayout &DL; VScaleVal_match(const DataLayout &DL) : DL(DL) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EC; return match_context(V, EC); } + template bool match_context(ITy *V, MatchContext & MContext) { if (m_Intrinsic().match(V)) return true; diff --git a/llvm/include/llvm/IR/PredicatedInst.h b/llvm/include/llvm/IR/PredicatedInst.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/PredicatedInst.h @@ -0,0 +1,438 @@ +//===-- llvm/PredicatedInst.h - Predication utility subclass --*- C++ -*-===// +// +// Part of 
the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines various classes for working with predicated instructions. +// Predicated instructions are either regular instructions or calls to +// Vector Predication (VP) intrinsics that have a mask and an explicit +// vector length argument. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_PREDICATEDINST_H +#define LLVM_IR_PREDICATEDINST_H + +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MatcherCast.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" + +#include + +namespace llvm { + +class BasicBlock; + +class PredicatedInstruction : public User { +public: + // The PredicatedInstruction class is intended to be used as a utility, and is + // never itself instantiated. + PredicatedInstruction() = delete; + ~PredicatedInstruction() = delete; + + void copyIRFlags(const Value *V, bool IncludeWrapFlags) { + cast(this)->copyIRFlags(V, IncludeWrapFlags); + } + + BasicBlock *getParent() { return cast(this)->getParent(); } + const BasicBlock *getParent() const { + return cast(this)->getParent(); + } + + void *operator new(size_t s) = delete; + + Value *getMaskParam() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getMaskParam(); + } + + Value *getVectorLengthParam() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getVectorLengthParam(); + } + + /// \returns True if the passed vector length value has no predicating effect + /// on the op. 
+ bool canIgnoreVectorLengthParam() const; + + /// \return True if the static operator of this instruction has a mask or + /// vector length parameter. + bool isVectorPredicatedOp() const { return isa(this); } + + /// \returns the effective Opcode of this operation (ignoring the mask and + /// vector length param). + unsigned getOpcode() const { + auto *VPInst = dyn_cast(this); + + if (!VPInst) { + return cast(this)->getOpcode(); + } + + return VPInst->getFunctionalOpcode(); + } + + bool isVectorReduction() const; + + static bool classof(const Instruction *I) { return isa(I); } + static bool classof(const ConstantExpr *CE) { return false; } + static bool classof(const Value *V) { return isa(V); } + + /// Convenience function for getting all the fast-math flags, which must be an + /// operator which supports these flags. See LangRef.html for the meaning of + /// these flags. + FastMathFlags getFastMathFlags() const; +}; + +class PredicatedOperator : public User { +public: + // The PredicatedOperator class is intended to be used as a utility, and is + // never itself instantiated. + PredicatedOperator() = delete; + ~PredicatedOperator() = delete; + + void *operator new(size_t s) = delete; + + /// Return the opcode for this Instruction or ConstantExpr. 
+ unsigned getOpcode() const { + auto *VPInst = dyn_cast(this); + + // Conceal the fp operation if it has non-default rounding mode or exception + // behavior + if (VPInst && !VPInst->isConstrainedOp()) { + return VPInst->getFunctionalOpcode(); + } + + if (const Instruction *I = dyn_cast(this)) + return I->getOpcode(); + + return cast(this)->getOpcode(); + } + + Value *getMask() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getMaskParam(); + } + + Value *getVectorLength() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getVectorLengthParam(); + } + + void copyIRFlags(const Value *V, bool IncludeWrapFlags = true); + FastMathFlags getFastMathFlags() const { + auto *I = dyn_cast(this); + if (I) + return I->getFastMathFlags(); + else + return FastMathFlags(); + } + + static bool classof(const Instruction *I) { + return isa(I) || isa(I); + } + static bool classof(const ConstantExpr *CE) { return isa(CE); } + static bool classof(const Value *V) { + return isa(V) || isa(V); + } +}; + +class PredicatedBinaryOperator : public PredicatedOperator { +public: + // The PredicatedBinaryOperator class is intended to be used as a utility, and + // is never itself instantiated. + PredicatedBinaryOperator() = delete; + ~PredicatedBinaryOperator() = delete; + + using BinaryOps = Instruction::BinaryOps; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->isBinaryOp(); + } + static bool classof(const ConstantExpr *CE) { + return isa(CE); + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + auto *CE = dyn_cast(V); + return CE && classof(CE); + } + + /// Construct a predicated binary instruction, given the opcode and the two + /// operands. 
+ static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + Instruction::BinaryOps Opc, Value *V1, Value *V2, + const Twine &Name, BasicBlock *InsertAtEnd, + Instruction *InsertBefore); + + static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + BinaryOps Opc, Value *V1, Value *V2, + const Twine &Name = Twine(), + Instruction *InsertBefore = nullptr) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, + InsertBefore); + } + + static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + BinaryOps Opc, Value *V1, Value *V2, + const Twine &Name, BasicBlock *InsertAtEnd) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, InsertAtEnd, + nullptr); + } + + static Instruction *CreateWithCopiedFlags(Module *Mod, Value *Mask, + Value *VectorLen, BinaryOps Opc, + Value *V1, Value *V2, + Instruction *CopyBO, + const Twine &Name = "") { + Instruction *BO = + Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, nullptr); + BO->copyIRFlags(CopyBO); + return BO; + } +}; + +class PredicatedICmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. 
+ PredicatedICmpInst() = delete; + ~PredicatedICmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::ICmp; + } + static bool classof(const ConstantExpr *CE) { + return CE->getOpcode() == Instruction::ICmp; + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + auto *CE = dyn_cast(V); + return CE && classof(CE); + } + + ICmpInst::Predicate getPredicate() const { + auto *ICInst = dyn_cast(this); + if (ICInst) + return ICInst->getPredicate(); + auto *CE = dyn_cast(this); + if (CE) + return static_cast(CE->getPredicate()); + return static_cast( + cast(this)->getCmpPredicate()); + } +}; + +class PredicatedFCmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedFCmpInst() = delete; + ~PredicatedFCmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::FCmp; + } + static bool classof(const ConstantExpr *CE) { + return CE->getOpcode() == Instruction::FCmp; + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + return isa(V); + } + + FCmpInst::Predicate getPredicate() const { + auto *FCInst = dyn_cast(this); + if (FCInst) + return FCInst->getPredicate(); + auto *CE = dyn_cast(this); + if (CE) + return static_cast(CE->getPredicate()); + return static_cast( + cast(this)->getCmpPredicate()); + } +}; + +class PredicatedSelectInst : public PredicatedOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. 
+ PredicatedSelectInst() = delete; + ~PredicatedSelectInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::Select; + } + static bool classof(const ConstantExpr *CE) { + return CE->getOpcode() == Instruction::Select; + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + auto *CE = dyn_cast(V); + return CE && CE->getOpcode() == Instruction::Select; + } + + const Value *getCondition() const { return getOperand(0); } + const Value *getTrueValue() const { return getOperand(1); } + const Value *getFalseValue() const { return getOperand(2); } + Value *getCondition() { return getOperand(0); } + Value *getTrueValue() { return getOperand(1); } + Value *getFalseValue() { return getOperand(2); } + + void setCondition(Value *V) { setOperand(0, V); } + void setTrueValue(Value *V) { setOperand(1, V); } + void setFalseValue(Value *V) { setOperand(2, V); } +}; + +namespace PatternMatch { + +// PredicatedMatchContext for pattern matching +struct PredicatedContext { + Value *Mask; + Value *VectorLength; + Module *Mod; + + void reset(Value *V) { + auto *PI = dyn_cast(V); + if (!PI) { + VectorLength = nullptr; + Mask = nullptr; + return; + } + VectorLength = PI->getVectorLengthParam(); + Mask = PI->getMaskParam(); + + if (Mod) return; + + // try to get a hold of the Module + auto *BB = PI->getParent(); + if (BB) { + auto *Func = BB->getParent(); + if (Func) { + Mod = Func->getParent(); + } + } + + if (Mod) return; + + // try to infer the module from a call + auto CallI = dyn_cast(V); + if (CallI && CallI->getCalledFunction()) { + Mod = CallI->getCalledFunction()->getParent(); + } + } + + PredicatedContext(Value *Val) + : Mask(nullptr), VectorLength(nullptr), Mod(nullptr) { + reset(Val); + } + + PredicatedContext(const PredicatedContext &PC) + : 
Mask(PC.Mask), VectorLength(PC.VectorLength), Mod(PC.Mod) {} + + /// accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value *Val) const { + auto PredI = dyn_cast(Val); + if (!PredI) + return VectorLength == nullptr && Mask == nullptr; + return VectorLength == PredI->getVectorLengthParam() && + Mask == PredI->getMaskParam(); + } + + /// accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value *Val) const { return true; } + + /// whether this context is compatible with \p PC. + bool acceptContext(PredicatedContext PC) const { + return std::tie(PC.Mask, PC.VectorLength) == std::tie(Mask, VectorLength); + } + + /// merge the context \p PC into this context and return whether the + /// resulting context is valid. + bool mergeContext(PredicatedContext PC) const { return acceptContext(PC); } + + /// match \p P in a new context for \p V. + template + bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + /// match \p P in the current context.
+ template + bool try_match(Val *V, const Pattern &P) { + PredicatedContext SubContext(*this); + return const_cast(P).match_context(V, SubContext); + } +}; + +struct PredicatedContext; +template <> struct MatcherCast { + using ActualCastType = PredicatedBinaryOperator; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedOperator; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedICmpInst; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedFCmpInst; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedSelectInst; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedInstruction; +}; + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_PREDICATEDINST_H diff --git a/llvm/include/llvm/IR/VPBuilder.h b/llvm/include/llvm/IR/VPBuilder.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/VPBuilder.h @@ -0,0 +1,184 @@ +#ifndef LLVM_IR_VPBUILDER_H +#define LLVM_IR_VPBUILDER_H + +#include +#include +#include +#include +#include +#include + +namespace llvm { + +using ValArray = ArrayRef; + +class VPBuilder { + IRBuilder<> & Builder; + + // Explicit mask parameter + Value * Mask; + // Explicit vector length parameter + Value * ExplicitVectorLength; + // Compile-time vector length + int StaticVectorLength; + + // get a valid mask/evl argument for the current predication context + Value& RequestPred(); + Value& RequestEVL(); + +public: + VPBuilder(IRBuilder<> & _builder) + : Builder(_builder) + , Mask(nullptr) + , ExplicitVectorLength(nullptr) + , StaticVectorLength(-1) + {} + + Module & getModule() const; + LLVMContext & getContext() const { return Builder.getContext(); } + + // The canonical vector type for this \p ElementTy + VectorType& getVectorType(Type &ElementTy); + + // Predication context tracker + VPBuilder& setMask(Value * _Mask) { Mask = _Mask; return *this; } + VPBuilder& setEVL(Value * _ExplicitVectorLength) {
ExplicitVectorLength = _ExplicitVectorLength; return *this; } + VPBuilder& setStaticVL(int VLen) { StaticVectorLength = VLen; return *this; } + + // Create a map-vectorized copy of the instruction \p Inst with the underlying IRBuilder instance. + // This operation may return nullptr if the instruction could not be vectorized. + Value* CreateVectorCopy(Instruction & Inst, ValArray VecOpArray); + + // shift the elements in \p SrcVal by Amount where the result lane is true. + Value* CreateVectorShift(Value *SrcVal, Value *Amount, Twine Name=""); + + // Memory + Value& CreateContiguousStore(Value & Val, Value & Pointer, MaybeAlign Alignment); + Value& CreateContiguousLoad(Value & Pointer, MaybeAlign Alignment); + Value& CreateScatter(Value & Val, Value & PointerVec, MaybeAlign Alignment); + Value& CreateGather(Value & PointerVec, MaybeAlign Alignment); +}; + + + + + +namespace PatternMatch { + // Factory class to generate instructions in a context + template + class MatchContextBuilder { + public: + // MatchContextBuilder(MatcherContext MC); + }; + + +// Context-free instruction builder +template<> +class MatchContextBuilder { +public: + MatchContextBuilder(EmptyContext & EC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Instruction *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const { \ + auto * Inst = BinaryOperator::Create(Instruction::OPC, V1, V2, Name); \ + Builder.Insert(Inst); return Inst; \ + } \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, BB);\ + } \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, I);\ + } \ + 
Instruction *Create##OPC##FMF(Value *V1, Value *V2, Instruction *FMFSource, \ + const Twine &Name = "") const {\ + return BinaryOperator::CreateWithCopiedFlags(Instruction::OPC, V1, V2, FMFSource, Name);\ + } \ + template \ + Instruction *Create##OPC##FMF(IRBuilderType& Builder, Value *V1, Value *V2, Instruction *FMFSource, \ + const Twine &Name = "") const {\ + auto * Inst = BinaryOperator::CreateWithCopiedFlags(Instruction::OPC, V1, V2, FMFSource, Name);\ + Builder.Insert(Inst); return Inst; \ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + UnaryOperator *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + return UnaryOperator::CreateFNegFMF(Op, FMFSource, Name); + } + + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + + + +// Predicated instruction builder (uses the Mod, Mask and VectorLength of a PredicatedContext) +template<> +class MatchContextBuilder { + PredicatedContext & PC; +public: + MatchContextBuilder(PredicatedContext & PC) : PC(PC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Instruction *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + auto * PredInst = Create##OPC(V1, V2, Name); \ + Builder.Insert(PredInst); \ + return PredInst; \ + } \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, BB);\ + } \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine
&Name, Instruction *I) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, I);\ + } \ + Instruction *Create##OPC##FMF(Value *V1, Value *V2, Instruction *FMFSource, \ + const Twine &Name = "") const {\ + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, FMFSource, Name);\ + } \ + template \ + Instruction *Create##OPC##FMF(IRBuilderType& Builder, Value *V1, Value *V2, Instruction *FMFSource, \ + const Twine &Name = "") const {\ + auto * Inst = PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, FMFSource, Name);\ + Builder.Insert(Inst); return Inst; \ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + Instruction *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + // FIXME use llvm.vp.fneg + Value *Zero = ConstantFP::getNegativeZero(Op->getType()); + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FSub, Zero, Op, FMFSource, Name); + } + + // TODO predicated casts + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + +} + +} // namespace llvm + +#endif // LLVM_IR_VPBUILDER_H diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/VPIntrinsics.def @@ -0,0 +1,619 @@ +//===-- IR/VPIntrinsics.def - Describes llvm.vp.* Intrinsics -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains descriptions of the various Vector Predication intrinsics. +// This is used as a central place for enumerating the different instructions +// and should eventually be the place to put comments about the instructions. +// +//===----------------------------------------------------------------------===// + +// NOTE: NO INCLUDE GUARD DESIRED! + +// Provide definitions of macros so that users of this file do not have to +// define everything to use it... +// +#ifndef REGISTER_VP_INTRINSIC +#define REGISTER_VP_INTRINSIC(VPID, MASKPOS, VLENPOS) +#endif + +// This is a reduction intrinsic with accumulator arg at ACCUPOS, reduced vector +// arg at VECTORPOS. +#ifndef HANDLE_VP_REDUCTION +#define HANDLE_VP_REDUCTION(VPID, ACCUPOS, VECTORPOS) +#endif + +// The intrinsic VPID of llvm.vp.* functionally corresponds to the intrinsic +// CPFID of llvm.experimental.constrained.*. +#ifndef HANDLE_VP_TO_CONSTRAINED_INTRIN +#define HANDLE_VP_TO_CONSTRAINED_INTRIN(VPID, CPFID) +#endif + +// This VP intrinsic has constrained fp params. +// Rounding mode arg pos is ROUNDPOS, exception behavior arg pos is EXCEPTPOS. +#ifndef HANDLE_VP_FPCONSTRAINT +#define HANDLE_VP_FPCONSTRAINT(VPID, ROUNDPOS, EXCEPTPOS) +#endif + +// Map this VP intrinsic to its functional Opcode +#ifndef HANDLE_VP_TO_OC +#define HANDLE_VP_TO_OC(VPID, OC) +#endif + +// Map this VP intrinsic to its canonical functional intrinsic.
+#ifndef HANDLE_VP_TO_INTRIN +#define HANDLE_VP_TO_INTRIN(VPID, ID) +#endif + +// This VP Intrinsic is a unary operator +// (only count data params) +#ifndef HANDLE_VP_IS_UNARY +#define HANDLE_VP_IS_UNARY(VPID) +#endif + +// This VP Intrinsic is a binary operator +// (only count data params) +#ifndef HANDLE_VP_IS_BINARY +#define HANDLE_VP_IS_BINARY(VPID) +#endif + +// This VP Intrinsic is a ternary operator +// (only count data params) +#ifndef HANDLE_VP_IS_TERNARY +#define HANDLE_VP_IS_TERNARY(VPID) +#endif + +// This VP Intrinsic is a comparison +// (only count data params) +#ifndef HANDLE_VP_IS_XCMP +#define HANDLE_VP_IS_XCMP(VPID) +#endif + +// This VP Intrinsic is a memory operation +// The pointer arg is at POINTERPOS and the data arg is at DATAPOS. +#ifndef HANDLE_VP_IS_MEMOP +#define HANDLE_VP_IS_MEMOP(VPID, POINTERPOS, DATAPOS) +#endif + +#ifndef REGISTER_VP_SDNODE +#define REGISTER_VP_SDNODE(NODEID, MASKPOS, VLENPOS, NAME) +#endif + +/// This VP Intrinsic lowers to this VP SDNode. 
+#ifndef HANDLE_VP_TO_SDNODE +#define HANDLE_VP_TO_SDNODE(VPID,NODEID) +#endif + +///// Integer Arithmetic ///// + +// llvm.vp.add(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_add, 2, 3) +HANDLE_VP_TO_OC(vp_add, Add) +HANDLE_VP_IS_BINARY(vp_add) +REGISTER_VP_SDNODE(VP_ADD,"vp_add", 2, 3) +HANDLE_VP_TO_SDNODE(vp_add,VP_ADD) + +// llvm.vp.and(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_and, 2, 3) +HANDLE_VP_TO_OC(vp_and, And) +HANDLE_VP_IS_BINARY(vp_and) +REGISTER_VP_SDNODE(VP_AND,"vp_and", 2, 3) +HANDLE_VP_TO_SDNODE(vp_and,VP_AND) + +// llvm.vp.ashr(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_ashr, 2, 3) +HANDLE_VP_TO_OC(vp_ashr, AShr) +HANDLE_VP_IS_BINARY(vp_ashr) +REGISTER_VP_SDNODE(VP_SRA,"vp_sra", 2, 3) +HANDLE_VP_TO_SDNODE(vp_ashr,VP_SRA) + +// llvm.vp.lshr(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_lshr, 2, 3) +HANDLE_VP_TO_OC(vp_lshr, LShr) +HANDLE_VP_IS_BINARY(vp_lshr) +REGISTER_VP_SDNODE(VP_SRL,"vp_srl", 2, 3) +HANDLE_VP_TO_SDNODE(vp_lshr,VP_SRL) + +// llvm.vp.mul(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_mul, 2, 3) +HANDLE_VP_TO_OC(vp_mul, Mul) +HANDLE_VP_IS_BINARY(vp_mul) +REGISTER_VP_SDNODE(VP_MUL,"vp_mul", 2, 3) +HANDLE_VP_TO_SDNODE(vp_mul,VP_MUL) + +// llvm.vp.or(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_or, 2, 3) +HANDLE_VP_TO_OC(vp_or, Or) +HANDLE_VP_IS_BINARY(vp_or) +REGISTER_VP_SDNODE(VP_OR,"vp_or", 2, 3) +HANDLE_VP_TO_SDNODE(vp_or,VP_OR) + +// llvm.vp.sdiv(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_sdiv, 2, 3) +HANDLE_VP_TO_OC(vp_sdiv, SDiv) +HANDLE_VP_IS_BINARY(vp_sdiv) +REGISTER_VP_SDNODE(VP_SDIV,"vp_sdiv", 2, 3) +HANDLE_VP_TO_SDNODE(vp_sdiv,VP_SDIV) + +// llvm.vp.shl(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_shl, 2, 3) +HANDLE_VP_TO_OC(vp_shl, Shl) +HANDLE_VP_IS_BINARY(vp_shl) +REGISTER_VP_SDNODE(VP_SHL,"vp_shl", 2, 3) +HANDLE_VP_TO_SDNODE(vp_shl,VP_SHL) + +// llvm.vp.srem(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_srem, 2, 3) +HANDLE_VP_TO_OC(vp_srem, SRem) +HANDLE_VP_IS_BINARY(vp_srem) +REGISTER_VP_SDNODE(VP_SREM,"vp_srem", 2, 3) 
+HANDLE_VP_TO_SDNODE(vp_srem,VP_SREM) + +// llvm.vp.sub(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_sub, 2, 3) +HANDLE_VP_TO_OC(vp_sub, Sub) +HANDLE_VP_IS_BINARY(vp_sub) +REGISTER_VP_SDNODE(VP_SUB,"vp_sub", 2, 3) +HANDLE_VP_TO_SDNODE(vp_sub,VP_SUB) + +// llvm.vp.udiv(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_udiv, 2, 3) +HANDLE_VP_TO_OC(vp_udiv, UDiv) +HANDLE_VP_IS_BINARY(vp_udiv) +REGISTER_VP_SDNODE(VP_UDIV,"vp_udiv", 2, 3) +HANDLE_VP_TO_SDNODE(vp_udiv,VP_UDIV) + +// llvm.vp.urem(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_urem, 2, 3) +HANDLE_VP_TO_OC(vp_urem, URem) +HANDLE_VP_IS_BINARY(vp_urem) +REGISTER_VP_SDNODE(VP_UREM,"vp_urem", 2, 3) +HANDLE_VP_TO_SDNODE(vp_urem,VP_UREM) + +// llvm.vp.xor(x,y,mask,vlen) +REGISTER_VP_INTRINSIC(vp_xor, 2, 3) +HANDLE_VP_TO_OC(vp_xor, Xor) +HANDLE_VP_IS_BINARY(vp_xor) +REGISTER_VP_SDNODE(VP_XOR,"vp_xor", 2, 3) +HANDLE_VP_TO_SDNODE(vp_xor,VP_XOR) + +///// FP Arithmetic ///// + +// llvm.vp.fadd(x,y,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_fadd, 4, 5) +HANDLE_VP_FPCONSTRAINT(vp_fadd, 2, 3) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_fadd, experimental_constrained_fadd) +HANDLE_VP_TO_OC(vp_fadd, FAdd) +HANDLE_VP_IS_BINARY(vp_fadd) +REGISTER_VP_SDNODE(VP_FADD,"vp_fadd", 4, 5) +HANDLE_VP_TO_SDNODE(vp_fadd,VP_FADD) + +// llvm.vp.fdiv(x,y,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_fdiv, 4, 5) +HANDLE_VP_FPCONSTRAINT(vp_fdiv, 2, 3) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_fdiv, experimental_constrained_fdiv) +HANDLE_VP_TO_OC(vp_fdiv, FDiv) +HANDLE_VP_IS_BINARY(vp_fdiv) +REGISTER_VP_SDNODE(VP_FDIV,"vp_fdiv", 4, 5) +HANDLE_VP_TO_SDNODE(vp_fdiv,VP_FDIV) + +// llvm.vp.fmul(x,y,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_fmul, 4, 5) +HANDLE_VP_FPCONSTRAINT(vp_fmul, 2, 3) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_fmul, experimental_constrained_fmul) +HANDLE_VP_TO_OC(vp_fmul, FMul) +HANDLE_VP_IS_BINARY(vp_fmul) +REGISTER_VP_SDNODE(VP_FMUL,"vp_fmul", 4, 5) +HANDLE_VP_TO_SDNODE(vp_fmul,VP_FMUL) + +// llvm.vp.fneg(x,except,mask,vlen) 
+REGISTER_VP_INTRINSIC(vp_fneg, 2, 3) +HANDLE_VP_FPCONSTRAINT(vp_fneg, None, 1) +HANDLE_VP_TO_OC(vp_fneg, FNeg) +HANDLE_VP_IS_UNARY(vp_fneg) +REGISTER_VP_SDNODE(VP_FNEG, "vp_fneg", 2, 3) +HANDLE_VP_TO_SDNODE(vp_fneg,VP_FNEG) + +// llvm.vp.frem(x,y,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_frem, 4, 5) +HANDLE_VP_FPCONSTRAINT(vp_frem, 2, 3) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_frem, experimental_constrained_frem) +HANDLE_VP_TO_OC(vp_frem, FRem) +HANDLE_VP_IS_BINARY(vp_frem) +REGISTER_VP_SDNODE(VP_FREM, "vp_frem", 4, 5) +HANDLE_VP_TO_SDNODE(vp_frem,VP_FREM) + +// llvm.vp.fsub(x,y,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_fsub, 4, 5) +HANDLE_VP_FPCONSTRAINT(vp_fsub, 2, 3) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_fsub, experimental_constrained_fsub) +HANDLE_VP_TO_OC(vp_fsub, FSub) +HANDLE_VP_IS_BINARY(vp_fsub) +REGISTER_VP_SDNODE(VP_FSUB, "vp_fsub", 4, 5) +HANDLE_VP_TO_SDNODE(vp_fsub,VP_FSUB) + +// llvm.vp.fma(x,y,z,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_fma, 5, 6) +HANDLE_VP_FPCONSTRAINT(vp_fma, 3, 4) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_fma, experimental_constrained_fma) +HANDLE_VP_TO_INTRIN(vp_fma, fma) +HANDLE_VP_IS_TERNARY(vp_fma) +REGISTER_VP_SDNODE(VP_FMA, "vp_fma", 5, 6) +HANDLE_VP_TO_SDNODE(vp_fma,VP_FMA) + +///// Cast, Extend & Round ///// + +// llvm.vp.ceil(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_ceil, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_ceil, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_ceil, experimental_constrained_ceil) +HANDLE_VP_TO_INTRIN(vp_ceil, ceil) +REGISTER_VP_SDNODE(VP_FCEIL, "vp_fceil", 3, 4) +HANDLE_VP_TO_SDNODE(vp_ceil,VP_FCEIL) + +// llvm.vp.trunc(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_trunc, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_trunc, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_trunc, experimental_constrained_trunc) +HANDLE_VP_TO_INTRIN(vp_trunc, trunc) +REGISTER_VP_SDNODE(VP_FTRUNC, "vp_ftrunc", 3, 4) +HANDLE_VP_TO_SDNODE(vp_trunc,VP_FTRUNC) + +// llvm.vp.floor(x,round,except,mask,vlen)
+REGISTER_VP_INTRINSIC(vp_floor, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_floor, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_floor, experimental_constrained_floor) +HANDLE_VP_TO_INTRIN(vp_floor, floor) +REGISTER_VP_SDNODE(VP_FFLOOR, "vp_ffloor", 3, 4) +HANDLE_VP_TO_SDNODE(vp_floor,VP_FFLOOR) + +// llvm.vp.fpext(x,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_fpext, 2, 3) +HANDLE_VP_FPCONSTRAINT(vp_fpext, None, 1) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_fpext, experimental_constrained_fpext) +HANDLE_VP_TO_OC(vp_fpext, FPExt) +REGISTER_VP_SDNODE(VP_FP_EXTEND, "vp_fpext", 2, 3) +HANDLE_VP_TO_SDNODE(vp_fpext,VP_FP_EXTEND) + +// llvm.vp.fptrunc(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_fptrunc, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_fptrunc, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_fptrunc, experimental_constrained_fptrunc) +HANDLE_VP_TO_OC(vp_fptrunc, FPTrunc) +HANDLE_VP_TO_SDNODE(vp_fptrunc,VP_FTRUNC) + +// llvm.vp.fptoui(x,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_fptoui, 2, 3) +HANDLE_VP_FPCONSTRAINT(vp_fptoui, None, 1) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_fptoui, experimental_constrained_fptoui) +HANDLE_VP_TO_OC(vp_fptoui, FPToUI) +REGISTER_VP_SDNODE(VP_FP_TO_UINT, "vp_fp_to_uint", 2, 3) +HANDLE_VP_TO_SDNODE(vp_fptoui,VP_FP_TO_UINT) + +// llvm.vp.fptosi(x,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_fptosi, 2, 3) +HANDLE_VP_FPCONSTRAINT(vp_fptosi, None, 1) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_fptosi, experimental_constrained_fptosi) +HANDLE_VP_TO_OC(vp_fptosi, FPToSI) +REGISTER_VP_SDNODE(VP_FP_TO_SINT, "vp_fp_to_sint", 2, 3) +HANDLE_VP_TO_SDNODE(vp_fptosi,VP_FP_TO_SINT) + +// llvm.vp.uitofp(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_uitofp, 3, 4) +HANDLE_VP_TO_OC(vp_uitofp, UIToFP) +HANDLE_VP_FPCONSTRAINT(vp_uitofp, 1, 2) +REGISTER_VP_SDNODE(VP_UINT_TO_FP, "vp_uint_to_fp", 3, 4) +HANDLE_VP_TO_SDNODE(vp_uitofp,VP_UINT_TO_FP) +// HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_uitofp,experimental_constrained_uitofp) + +// llvm.vp.sitofp(x,round,except,mask,vlen) 
+REGISTER_VP_INTRINSIC(vp_sitofp, 3, 4) +HANDLE_VP_TO_OC(vp_sitofp, SIToFP) +HANDLE_VP_FPCONSTRAINT(vp_sitofp, 1, 2) +REGISTER_VP_SDNODE(VP_SINT_TO_FP, "vp_sint_to_fp", 3, 4) +HANDLE_VP_TO_SDNODE(vp_sitofp,VP_SINT_TO_FP) +// HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_sitofp,experimental_constrained_sitofp) + +// llvm.vp.round(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_round, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_round, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_round, experimental_constrained_round) +HANDLE_VP_TO_INTRIN(vp_round, round) +REGISTER_VP_SDNODE(VP_FROUND, "vp_fround", 3, 4) +HANDLE_VP_TO_SDNODE(vp_round,VP_FROUND) + +// llvm.vp.rint(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_rint, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_rint, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_rint, experimental_constrained_rint) +HANDLE_VP_TO_INTRIN(vp_rint, rint) +REGISTER_VP_SDNODE(VP_FRINT, "vp_frint", 3, 4) +HANDLE_VP_TO_SDNODE(vp_rint,VP_FRINT) + +// llvm.vp.nearbyint(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_nearbyint, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_nearbyint, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_nearbyint, + experimental_constrained_nearbyint) +HANDLE_VP_TO_INTRIN(vp_nearbyint, nearbyint) +REGISTER_VP_SDNODE(VP_FNEARBYINT, "vp_fnearbyint", 3, 4) +HANDLE_VP_TO_SDNODE(vp_nearbyint,VP_FNEARBYINT) + +///// Math Funcs ///// + +// llvm.vp.sqrt(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_sqrt, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_sqrt, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_sqrt, experimental_constrained_sqrt) +HANDLE_VP_TO_INTRIN(vp_sqrt, sqrt) +REGISTER_VP_SDNODE(VP_FSQRT, "vp_fsqrt", 3, 4) +HANDLE_VP_TO_SDNODE(vp_sqrt,VP_FSQRT) + +// llvm.vp.pow(x,y,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_pow, 4, 5) +HANDLE_VP_FPCONSTRAINT(vp_pow, 2, 3) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_pow, experimental_constrained_pow) +HANDLE_VP_TO_INTRIN(vp_pow, pow) +REGISTER_VP_SDNODE(VP_FPOW, "vp_fpow", 4, 5) +HANDLE_VP_TO_SDNODE(vp_pow,VP_FPOW) + +// 
llvm.vp.powi(x,y,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_powi, 4, 5) +HANDLE_VP_FPCONSTRAINT(vp_powi, 2, 3) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_powi, experimental_constrained_powi) +HANDLE_VP_TO_INTRIN(vp_powi, powi) +REGISTER_VP_SDNODE(VP_FPOWI, "vp_fpowi", 4, 5) +HANDLE_VP_TO_SDNODE(vp_powi,VP_FPOWI) + +// llvm.vp.maxnum(x,y,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_maxnum, 4, 5) +HANDLE_VP_FPCONSTRAINT(vp_maxnum, 2, 3) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_maxnum, experimental_constrained_maxnum) +HANDLE_VP_TO_INTRIN(vp_maxnum, maxnum) +REGISTER_VP_SDNODE(VP_FMAXNUM, "vp_fmaxnum", 4, 5) +HANDLE_VP_TO_SDNODE(vp_maxnum,VP_FMAXNUM) + +// llvm.vp.minnum(x,y,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_minnum, 4, 5) +HANDLE_VP_FPCONSTRAINT(vp_minnum, 2, 3) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_minnum, experimental_constrained_minnum) +HANDLE_VP_TO_INTRIN(vp_minnum, minnum) +REGISTER_VP_SDNODE(VP_FMINNUM, "vp_fminnum", 4, 5) +HANDLE_VP_TO_SDNODE(vp_minnum,VP_FMINNUM) + +// llvm.vp.sin(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_sin, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_sin, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_sin, experimental_constrained_sin) +HANDLE_VP_TO_INTRIN(vp_sin, sin) +REGISTER_VP_SDNODE(VP_FSIN, "vp_fsin", 3, 4) +HANDLE_VP_TO_SDNODE(vp_sin,VP_FSIN) + +// llvm.vp.cos(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_cos, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_cos, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_cos, experimental_constrained_cos) +HANDLE_VP_TO_INTRIN(vp_cos, cos) +REGISTER_VP_SDNODE(VP_FCOS, "vp_fcos", 3, 4) +HANDLE_VP_TO_SDNODE(vp_cos,VP_FCOS) + +// llvm.vp.log(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_log, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_log, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_log, experimental_constrained_log) +HANDLE_VP_TO_INTRIN(vp_log, log) +REGISTER_VP_SDNODE(VP_FLOG, "vp_flog", 3, 4) +HANDLE_VP_TO_SDNODE(vp_log,VP_FLOG) + +// llvm.vp.log10(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_log10, 
3, 4) +HANDLE_VP_FPCONSTRAINT(vp_log10, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_log10, experimental_constrained_log10) +HANDLE_VP_TO_INTRIN(vp_log10, log10) +REGISTER_VP_SDNODE(VP_FLOG10, "vp_flog10", 3, 4) +HANDLE_VP_TO_SDNODE(vp_log10,VP_FLOG10) + +// llvm.vp.log2(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_log2, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_log2, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_log2, experimental_constrained_log2) +HANDLE_VP_TO_INTRIN(vp_log2, log2) +REGISTER_VP_SDNODE(VP_FLOG2, "vp_flog2", 3, 4) +HANDLE_VP_TO_SDNODE(vp_log2,VP_FLOG2) + +// llvm.vp.exp(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_exp, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_exp, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_exp, experimental_constrained_exp) +HANDLE_VP_TO_INTRIN(vp_exp, exp) +REGISTER_VP_SDNODE(VP_FEXP, "vp_fexp", 3, 4) +HANDLE_VP_TO_SDNODE(vp_exp,VP_FEXP) + +// llvm.vp.exp2(x,round,except,mask,vlen) +REGISTER_VP_INTRINSIC(vp_exp2, 3, 4) +HANDLE_VP_FPCONSTRAINT(vp_exp2, 1, 2) +HANDLE_VP_TO_CONSTRAINED_INTRIN(vp_exp2, experimental_constrained_exp2) +HANDLE_VP_TO_INTRIN(vp_exp2, exp2) +REGISTER_VP_SDNODE(VP_FEXP2, "vp_fexp2", 3, 4) +HANDLE_VP_TO_SDNODE(vp_exp2,VP_FEXP2) + +///// Comparison ///// + +// llvm.vp.fcmp(x,y,pred,mask,vlen) +REGISTER_VP_INTRINSIC(vp_fcmp, 3, 4) +HANDLE_VP_TO_OC(vp_fcmp, FCmp) +HANDLE_VP_IS_XCMP(vp_fcmp) + +// llvm.vp.icmp(x,y,cmp_pred,mask,vlen) +REGISTER_VP_INTRINSIC(vp_icmp, 3, 4) +HANDLE_VP_TO_OC(vp_icmp, ICmp) +HANDLE_VP_IS_XCMP(vp_icmp) + +REGISTER_VP_SDNODE(VP_SETCC, "vp_setcc", 3, 4) + +///// Memory Operations ///// + +// llvm.vp.store(ptr,val,mask,vlen) +REGISTER_VP_INTRINSIC(vp_store, 2, 3) +HANDLE_VP_TO_OC(vp_store, Store) +HANDLE_VP_TO_INTRIN(vp_store, masked_store) +HANDLE_VP_IS_MEMOP(vp_store, 1, 0) +REGISTER_VP_SDNODE(VP_STORE, "vp_store", 3, 4) + +// llvm.vp.scatter(ptr,val,mask,vlen) +REGISTER_VP_INTRINSIC(vp_scatter, 2, 3) +HANDLE_VP_TO_INTRIN(vp_scatter, masked_scatter) +HANDLE_VP_IS_MEMOP(vp_scatter, 1, 0) 
+REGISTER_VP_SDNODE(VP_SCATTER, "vp_scatter", 3, 4) + +// llvm.vp.load(ptr,mask,vlen) +REGISTER_VP_INTRINSIC(vp_load, 1, 2) +HANDLE_VP_TO_OC(vp_load, Load) +HANDLE_VP_TO_INTRIN(vp_load, masked_load) +HANDLE_VP_IS_MEMOP(vp_load, 0, None) +REGISTER_VP_SDNODE(VP_LOAD, "vp_load", 2, 3) + +// llvm.vp.gather(ptr,mask,vlen) +REGISTER_VP_INTRINSIC(vp_gather, 1, 2) +HANDLE_VP_TO_INTRIN(vp_gather, masked_gather) +HANDLE_VP_IS_MEMOP(vp_gather, 0, None) +REGISTER_VP_SDNODE(VP_GATHER, "vp_gather", 1, 2) + +///// Shuffle & Blend ///// + +// llvm.vp.compress(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_compress, 1, 2) +REGISTER_VP_SDNODE(VP_COMPRESS, "vp_compress", 1, 2) +HANDLE_VP_TO_SDNODE(vp_compress,VP_COMPRESS) + +// llvm.vp.expand(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_expand, 1, 2) +REGISTER_VP_SDNODE(VP_EXPAND, "vp_expand", 1, 2) +HANDLE_VP_TO_SDNODE(vp_expand,VP_EXPAND) + +// llvm.vp.vshift(x,amount,mask,vlen) +REGISTER_VP_INTRINSIC(vp_vshift, 2, 3) +REGISTER_VP_SDNODE(VP_VSHIFT, "vp_vshift", 2, 3) +HANDLE_VP_TO_SDNODE(vp_vshift,VP_VSHIFT) + +// llvm.vp.select(mask,on_true,on_false,vlen) +REGISTER_VP_INTRINSIC(vp_select, 0, 3) +HANDLE_VP_TO_OC(vp_select, Select) +REGISTER_VP_SDNODE(VP_SELECT, "vp_select", 0, 3) +HANDLE_VP_TO_SDNODE(vp_select,VP_SELECT) + +// llvm.vp.compose(x,y,pivot,vlen) +REGISTER_VP_INTRINSIC(vp_compose, None, 3) +REGISTER_VP_SDNODE(VP_COMPOSE, "vp_compose", None, 3) +HANDLE_VP_TO_SDNODE(vp_compose,VP_COMPOSE) + +///// Reduction ///// + +// llvm.vp.reduce.add(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_add, 1, 2) +HANDLE_VP_REDUCTION(vp_reduce_add, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_add, experimental_vector_reduce_add) +REGISTER_VP_SDNODE(VP_REDUCE_ADD, "vp_reduce_add", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_add,VP_REDUCE_ADD) + +// llvm.vp.reduce.mul(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_mul, 1, 2) +HANDLE_VP_REDUCTION(vp_reduce_mul, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_mul, experimental_vector_reduce_mul) +REGISTER_VP_SDNODE(VP_REDUCE_MUL, 
"vp_reduce_mul", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_mul,VP_REDUCE_MUL) + +// llvm.vp.reduce.and(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_and, 1, 2) +HANDLE_VP_REDUCTION(vp_reduce_and, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_and, experimental_vector_reduce_and) +REGISTER_VP_SDNODE(VP_REDUCE_AND, "vp_reduce_and", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_and,VP_REDUCE_AND) + +// llvm.vp.reduce.or(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_or, 1, 2) +HANDLE_VP_REDUCTION(vp_reduce_or, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_or, experimental_vector_reduce_or) +REGISTER_VP_SDNODE(VP_REDUCE_OR, "vp_reduce_or", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_or,VP_REDUCE_OR) + +// llvm.vp.reduce.xor(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_xor, 1, 2) +HANDLE_VP_REDUCTION(vp_reduce_xor, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_xor, experimental_vector_reduce_xor) +REGISTER_VP_SDNODE(VP_REDUCE_XOR, "vp_reduce_xor", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_xor,VP_REDUCE_XOR) + +// llvm.vp.reduce.smin(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_smin, 1, 2) +HANDLE_VP_REDUCTION(vp_reduce_smin, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_smin, experimental_vector_reduce_smin) +REGISTER_VP_SDNODE(VP_REDUCE_SMIN, "vp_reduce_smin", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_smin,VP_REDUCE_SMIN) + +// llvm.vp.reduce.smax(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_smax, 1, 2) +HANDLE_VP_REDUCTION(vp_reduce_smax, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_smax, experimental_vector_reduce_smax) +REGISTER_VP_SDNODE(VP_REDUCE_SMAX, "vp_reduce_smax", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_smax,VP_REDUCE_SMAX) + +// llvm.vp.reduce.umin(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_umin, 1, 2) +HANDLE_VP_REDUCTION(vp_reduce_umin, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_umin, experimental_vector_reduce_umin) +REGISTER_VP_SDNODE(VP_REDUCE_UMIN, "vp_reduce_umin", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_umin,VP_REDUCE_UMIN) + +// llvm.vp.reduce.umax(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_umax, 1, 2) 
+HANDLE_VP_REDUCTION(vp_reduce_umax, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_umax, experimental_vector_reduce_umax) +REGISTER_VP_SDNODE(VP_REDUCE_UMAX, "vp_reduce_umax", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_umax,VP_REDUCE_UMAX) + +// llvm.vp.reduce.fadd(accu,x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_fadd, 2, 3) +HANDLE_VP_REDUCTION(vp_reduce_fadd, 0, 1) +HANDLE_VP_TO_INTRIN(vp_reduce_fadd, experimental_vector_reduce_v2_fadd) +REGISTER_VP_SDNODE(VP_REDUCE_FADD, "vp_reduce_fadd", 2, 3) +HANDLE_VP_TO_SDNODE(vp_reduce_fadd,VP_REDUCE_FADD) + +// llvm.vp.reduce.fmul(accu,x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_fmul, 2, 3) +HANDLE_VP_REDUCTION(vp_reduce_fmul, 0, 1) +HANDLE_VP_TO_INTRIN(vp_reduce_fmul, experimental_vector_reduce_v2_fmul) +REGISTER_VP_SDNODE(VP_REDUCE_FMUL, "vp_reduce_fmul", 2, 3) +HANDLE_VP_TO_SDNODE(vp_reduce_fmul,VP_REDUCE_FMUL) + +// llvm.vp.reduce.fmin(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_fmin, 1, 2) +HANDLE_VP_REDUCTION(vp_reduce_fmin, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_fmin, experimental_vector_reduce_fmin) +REGISTER_VP_SDNODE(VP_REDUCE_FMIN, "vp_reduce_fmin", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_fmin,VP_REDUCE_FMIN) + +// llvm.vp.reduce.fmax(x,mask,vlen) +REGISTER_VP_INTRINSIC(vp_reduce_fmax, 1, 2) +HANDLE_VP_REDUCTION(vp_reduce_fmax, None, 0) +HANDLE_VP_TO_INTRIN(vp_reduce_fmax, experimental_vector_reduce_fmax) +REGISTER_VP_SDNODE(VP_REDUCE_FMAX, "vp_reduce_fmax", 1, 2) +HANDLE_VP_TO_SDNODE(vp_reduce_fmax,VP_REDUCE_FMAX) + +#undef REGISTER_VP_INTRINSIC +#undef REGISTER_VP_SDNODE +#undef HANDLE_VP_IS_UNARY +#undef HANDLE_VP_IS_BINARY +#undef HANDLE_VP_IS_TERNARY +#undef HANDLE_VP_IS_XCMP +#undef HANDLE_VP_IS_MEMOP +#undef HANDLE_VP_TO_OC +#undef HANDLE_VP_TO_CONSTRAINED_INTRIN +#undef HANDLE_VP_TO_INTRIN +#undef HANDLE_VP_FPCONSTRAINT +#undef HANDLE_VP_REDUCTION +#undef HANDLE_VP_TO_SDNODE diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h --- a/llvm/include/llvm/InitializePasses.h +++ 
b/llvm/include/llvm/InitializePasses.h @@ -145,6 +145,7 @@ void initializeExpandMemCmpPassPass(PassRegistry&); void initializeExpandPostRAPass(PassRegistry&); void initializeExpandReductionsPass(PassRegistry&); +void initializeExpandVectorPredicationPass(PassRegistry&); void initializeMakeGuardsExplicitLegacyPassPass(PassRegistry&); void initializeExternalAAWrapperPassPass(PassRegistry&); void initializeFEntryInserterPass(PassRegistry&); diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td --- a/llvm/include/llvm/Target/TargetSelectionDAG.td +++ b/llvm/include/llvm/Target/TargetSelectionDAG.td @@ -128,6 +128,13 @@ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<3> ]>; +def SDTIntBinOpVP : SDTypeProfile<1, 4, [ // vp_add, vp_and, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; +def SDTIntShiftOpVP : SDTypeProfile<1, 4, [ // shl, sra, srl + SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; + def SDTFPBinOp : SDTypeProfile<1, 2, [ // fadd, fmul, etc. SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0> ]>; @@ -164,6 +171,12 @@ def SDTFPToIntOp : SDTypeProfile<1, 1, [ // fp_to_[su]int SDTCisInt<0>, SDTCisFP<1>, SDTCisSameNumEltsAs<0, 1> ]>; +def SDTIntToFPOpVP : SDTypeProfile<1, 3, [ // [su]int_to_fp + SDTCisFP<0>, SDTCisInt<1>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<3>, SDTCisSameNumEltsAs<0, 2> +]>; +def SDTFPToIntOpVP : SDTypeProfile<1, 3, [ // fp_to_[su]int + SDTCisInt<0>, SDTCisFP<1>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<3>, SDTCisSameNumEltsAs<0, 2> +]>; def SDTExtInreg : SDTypeProfile<1, 2, [ // sext_inreg SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisVT<2, OtherVT>, SDTCisVTSmallerThanOp<2, 1> @@ -173,9 +186,22 @@ SDTCisOpSmallerThanOp<1, 0> ]>; +def SDTFPUnOpVP : SDTypeProfile<1, 3, [ // vp_fneg, etc. 
+ SDTCisSameAs<0, 1>, SDTCisFP<0>, SDTCisInt<2>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3> +]>; +def SDTFPBinOpVP : SDTypeProfile<1, 4, [ // vp_fadd, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>, SDTCisInt<3>, SDTCisSameNumEltsAs<0, 3>, SDTCisInt<4> +]>; +def SDTFPTernaryOpVP : SDTypeProfile<1, 5, [ // vp_fmadd, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 4>, SDTCisInt<5> +]>; + def SDTSetCC : SDTypeProfile<1, 3, [ // setcc SDTCisInt<0>, SDTCisSameAs<1, 2>, SDTCisVT<3, OtherVT> ]>; +def SDTSetCCVP : SDTypeProfile<1, 5, [ // vp_setcc + SDTCisInt<0>, SDTCisSameAs<1, 2>, SDTCisVT<3, OtherVT>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 4>, SDTCisInt<5> +]>; def SDTSelect : SDTypeProfile<1, 3, [ // select SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3> @@ -185,6 +211,10 @@ SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1> ]>; +def SDTSelectVP : SDTypeProfile<1, 4, [ // vp_select + SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<4> +]>; + def SDTSelectCC : SDTypeProfile<1, 5, [ // select_cc SDTCisSameAs<1, 2>, SDTCisSameAs<3, 4>, SDTCisSameAs<0, 3>, SDTCisVT<5, OtherVT> @@ -233,9 +263,30 @@ SDTCisSameNumEltsAs<0, 3> ]>; +def SDTStoreVP: SDTypeProfile<0, 4, [ // vp store + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3> +]>; + +// scatter (Value, BasePtr, Index, Scale, Mask, Vlen) +def SDTScatterVP: SDTypeProfile<0, 6, [ // vp scatter + SDTCisVec<0>, SDTCisInt<1>, SDTCisVec<2>, SDTCisInt<3>, SDTCisVec<4>, SDTCisSameNumEltsAs<0, 2>, SDTCisSameNumEltsAs<2, 4>, SDTCisInt<5> +]>; + +// gather (BasePtr, Index, Scale, Mask, Vlen) +def SDTGatherVP: SDTypeProfile<1, 5, [ // vp gather + SDTCisVec<0>, SDTCisInt<1>, SDTCisVec<2>, SDTCisInt<3>, SDTCisVec<4>, SDTCisSameNumEltsAs<0, 2>, SDTCisSameNumEltsAs<2, 4>, SDTCisInt<5> +]>; + +def SDTLoadVP : SDTypeProfile<1, 3, 
[ // vp load + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3> +]>; + def SDTVecShuffle : SDTypeProfile<1, 2, [ SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2> ]>; +def SDTVShiftVP : SDTypeProfile<1, 4, [ + SDTCisSameAs<0, 1>, SDTCisVec<0>, SDTCisVec<3>, SDTCisSameNumEltsAs<3,1>, SDTCisInt<4> +]>; def SDTVecExtract : SDTypeProfile<1, 2, [ // vector extract SDTCisEltOfVec<0, 1>, SDTCisPtrTy<2> ]>; @@ -245,6 +296,12 @@ def SDTVecReduce : SDTypeProfile<1, 1, [ // vector reduction SDTCisInt<0>, SDTCisVec<1> ]>; +def SDTReduceVP : SDTypeProfile<1, 3, [ // vp_reduce (no start arg) + SDTCisVec<1>, SDTCisInt<2>, SDTCisVec<2>, SDTCisInt<3>, SDTCisSameNumEltsAs<1,2> +]>; +def SDTReduceStartVP : SDTypeProfile<1, 4, [ // vp_reduce (with start arg) + SDTCisVec<2>, SDTCisInt<3>, SDTCisVec<3>, SDTCisInt<4>, SDTCisSameNumEltsAs<2,3> +]>; def SDTSubVecExtract : SDTypeProfile<1, 2, [// subvector extract SDTCisSubVecOfVec<0,1>, SDTCisInt<2> @@ -392,6 +449,22 @@ def umax : SDNode<"ISD::UMAX" , SDTIntBinOp, [SDNPCommutative, SDNPAssociative]>; +// TODO SDNPCommutative/SDNPAssociative for VP operators. 
+def vp_and : SDNode<"ISD::VP_AND" , SDTIntBinOpVP>; +def vp_or : SDNode<"ISD::VP_OR" , SDTIntBinOpVP>; +def vp_xor : SDNode<"ISD::VP_XOR" , SDTIntBinOpVP>; +def vp_srl : SDNode<"ISD::VP_SRL" , SDTIntShiftOpVP>; +def vp_sra : SDNode<"ISD::VP_SRA" , SDTIntShiftOpVP>; +def vp_shl : SDNode<"ISD::VP_SHL" , SDTIntShiftOpVP>; + +def vp_add : SDNode<"ISD::VP_ADD" , SDTIntBinOpVP>; +def vp_sub : SDNode<"ISD::VP_SUB" , SDTIntBinOpVP>; +def vp_mul : SDNode<"ISD::VP_MUL" , SDTIntBinOpVP>; +def vp_sdiv : SDNode<"ISD::VP_SDIV" , SDTIntBinOpVP>; +def vp_udiv : SDNode<"ISD::VP_UDIV" , SDTIntBinOpVP>; +def vp_srem : SDNode<"ISD::VP_SREM" , SDTIntBinOpVP>; +def vp_urem : SDNode<"ISD::VP_UREM" , SDTIntBinOpVP>; + def saddsat : SDNode<"ISD::SADDSAT" , SDTIntBinOp, [SDNPCommutative]>; def uaddsat : SDNode<"ISD::UADDSAT" , SDTIntBinOp, [SDNPCommutative]>; def ssubsat : SDNode<"ISD::SSUBSAT" , SDTIntBinOp>; @@ -433,6 +506,20 @@ def vecreduce_smin : SDNode<"ISD::VECREDUCE_SMIN", SDTVecReduce>; def vecreduce_umin : SDNode<"ISD::VECREDUCE_UMIN", SDTVecReduce>; +def vp_reduce_add : SDNode<"ISD::VP_REDUCE_ADD", SDTReduceVP>; +def vp_reduce_smax : SDNode<"ISD::VP_REDUCE_SMAX", SDTReduceVP>; +def vp_reduce_umax : SDNode<"ISD::VP_REDUCE_UMAX", SDTReduceVP>; +def vp_reduce_smin : SDNode<"ISD::VP_REDUCE_SMIN", SDTReduceVP>; +def vp_reduce_umin : SDNode<"ISD::VP_REDUCE_UMIN", SDTReduceVP>; +def vp_reduce_and : SDNode<"ISD::VP_REDUCE_AND", SDTReduceVP>; +def vp_reduce_or : SDNode<"ISD::VP_REDUCE_OR", SDTReduceVP>; +def vp_reduce_xor : SDNode<"ISD::VP_REDUCE_XOR", SDTReduceVP>; + +def vp_reduce_fadd : SDNode<"ISD::VP_REDUCE_FADD", SDTReduceStartVP>; +def vp_reduce_fmul : SDNode<"ISD::VP_REDUCE_FMUL", SDTReduceStartVP>; +def vp_reduce_fmin : SDNode<"ISD::VP_REDUCE_FMIN", SDTReduceVP>; +def vp_reduce_fmax : SDNode<"ISD::VP_REDUCE_FMAX", SDTReduceVP>; + def fadd : SDNode<"ISD::FADD" , SDTFPBinOp, [SDNPCommutative]>; def fsub : SDNode<"ISD::FSUB" , SDTFPBinOp>; def fmul : SDNode<"ISD::FMUL" , 
SDTFPBinOp, [SDNPCommutative]>; @@ -478,6 +565,18 @@ def fpextend : SDNode<"ISD::FP_EXTEND" , SDTFPExtendOp>; def fcopysign : SDNode<"ISD::FCOPYSIGN" , SDTFPSignOp>; +// vector predication + +def vp_fneg : SDNode<"ISD::VP_FNEG" , SDTFPUnOpVP>; +def vp_fadd : SDNode<"ISD::VP_FADD" , SDTFPBinOpVP>; +def vp_fsub : SDNode<"ISD::VP_FSUB" , SDTFPBinOpVP>; +def vp_fmul : SDNode<"ISD::VP_FMUL" , SDTFPBinOpVP>; +def vp_fdiv : SDNode<"ISD::VP_FDIV" , SDTFPBinOpVP>; +def vp_frem : SDNode<"ISD::VP_FREM" , SDTFPBinOpVP>; +def vp_fminnum : SDNode<"ISD::VP_FMINNUM" , SDTFPBinOpVP>; +def vp_fmaxnum : SDNode<"ISD::VP_FMAXNUM" , SDTFPBinOpVP>; +def vp_fma : SDNode<"ISD::VP_FMA" , SDTFPTernaryOpVP>; + def sint_to_fp : SDNode<"ISD::SINT_TO_FP" , SDTIntToFPOp>; def uint_to_fp : SDNode<"ISD::UINT_TO_FP" , SDTIntToFPOp>; def fp_to_sint : SDNode<"ISD::FP_TO_SINT" , SDTFPToIntOp>; @@ -485,6 +584,11 @@ def f16_to_fp : SDNode<"ISD::FP16_TO_FP" , SDTIntToFPOp>; def fp_to_f16 : SDNode<"ISD::FP_TO_FP16" , SDTFPToIntOp>; +def vp_sint_to_fp : SDNode<"ISD::VP_SINT_TO_FP" , SDTIntToFPOpVP>; +def vp_uint_to_fp : SDNode<"ISD::VP_UINT_TO_FP" , SDTIntToFPOpVP>; +def vp_fp_to_sint : SDNode<"ISD::VP_FP_TO_SINT" , SDTFPToIntOpVP>; +def vp_fp_to_uint : SDNode<"ISD::VP_FP_TO_UINT" , SDTFPToIntOpVP>; + def strict_fadd : SDNode<"ISD::STRICT_FADD", SDTFPBinOp, [SDNPHasChain, SDNPCommutative]>; def strict_fsub : SDNode<"ISD::STRICT_FSUB", @@ -555,8 +659,10 @@ SDTIntToFPOp, [SDNPHasChain]>; def setcc : SDNode<"ISD::SETCC" , SDTSetCC>; +def vp_setcc : SDNode<"ISD::VP_SETCC" , SDTSetCCVP>; def select : SDNode<"ISD::SELECT" , SDTSelect>; def vselect : SDNode<"ISD::VSELECT" , SDTVSelect>; +def vp_select : SDNode<"ISD::VP_SELECT" , SDTSelectVP>; def selectcc : SDNode<"ISD::SELECT_CC" , SDTSelectCC>; def brcc : SDNode<"ISD::BR_CC" , SDTBrCC, [SDNPHasChain]>; @@ -623,6 +729,16 @@ def masked_ld : SDNode<"ISD::MLOAD", SDTMaskedLoad, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; +def vp_store : SDNode<"ISD::VP_STORE", 
SDTStoreVP, + [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; +def vp_load : SDNode<"ISD::VP_LOAD", SDTLoadVP, + [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; + +def vp_scatter : SDNode<"ISD::VP_SCATTER", SDTScatterVP, + [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; +def vp_gather : SDNode<"ISD::VP_GATHER", SDTGatherVP, + [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; + // Do not use ld, st directly. Use load, extload, sextload, zextload, store, // and truncst (see below). def ld : SDNode<"ISD::LOAD" , SDTLoad, @@ -631,7 +747,7 @@ [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; def ist : SDNode<"ISD::STORE" , SDTIStore, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; - +def vp_vshift : SDNode<"ISD::VP_VSHIFT", SDTVShiftVP, []>; def vector_shuffle : SDNode<"ISD::VECTOR_SHUFFLE", SDTVecShuffle, []>; def build_vector : SDNode<"ISD::BUILD_VECTOR", SDTypeProfile<1, -1, []>, []>; def scalar_to_vector : SDNode<"ISD::SCALAR_TO_VECTOR", SDTypeProfile<1, 1, []>, diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/KnownBits.h" #include @@ -4672,8 +4673,10 @@ /// Given operands for an FSub, see if we can fold the result. If not, this /// returns null. 
-static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { +template +static Value *SimplifyFSubInstGeneric(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse, MatchContext & MC) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) return C; @@ -4681,26 +4684,26 @@ return C; // fsub X, +0 ==> X - if (match(Op1, m_PosZeroFP())) + if (MC.try_match(Op1, m_PosZeroFP())) return Op0; // fsub X, -0 ==> X, when we know X is not -0 - if (match(Op1, m_NegZeroFP()) && + if (MC.try_match(Op1, m_NegZeroFP()) && (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; // fsub -0.0, (fsub -0.0, X) ==> X // fsub -0.0, (fneg X) ==> X Value *X; - if (match(Op0, m_NegZeroFP()) && - match(Op1, m_FNeg(m_Value(X)))) + if (MC.try_match(Op0, m_NegZeroFP()) && + MC.try_match(Op1, m_FNeg(m_Value(X)))) return X; // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. // fsub 0.0, (fneg X) ==> X if signed zeros are ignored. if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) && - (match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) || - match(Op1, m_FNeg(m_Value(X))))) + (MC.try_match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) || + MC.try_match(Op1, m_FNeg(m_Value(X))))) return X; // fsub nnan x, x ==> 0.0 @@ -4710,8 +4713,8 @@ // Y - (Y - X) --> X // (X + Y) - Y --> X if (FMF.noSignedZeros() && FMF.allowReassoc() && - (match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || - match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) + (MC.try_match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || + MC.try_match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) return X; return nullptr; @@ -4766,9 +4769,26 @@ } + +/// Given operands for an FSub, see if we can fold the result. 
+static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) + return C; + + EmptyContext EC; + return SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, MaxRecurse, EC); +} + Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q) { - return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); + // '::' selects the file-local five-argument helper, which the llvm:: overload hides. + return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +Value *llvm::SimplifyPredicatedFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, PredicatedContext & PC) { + return ::SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, RecursionLimit, PC); } Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, @@ -5390,9 +5410,20 @@ return ::SimplifyFreezeInst(Op0, Q); } +Value *llvm::SimplifyVPIntrinsic(VPIntrinsic & VPInst, const SimplifyQuery &Q) { + PredicatedContext PC(&VPInst); + + auto & PI = cast(VPInst); + switch (PI.getOpcode()) { + default: + return nullptr; + + case Instruction::FSub: return SimplifyPredicatedFSubInst(VPInst.getOperand(0), VPInst.getOperand(1), VPInst.getFastMathFlags(), Q, PC); + } +} + /// See if we can compute a simplified version of this instruction. /// If not, this returns null. - Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); @@ -5529,6 +5560,13 @@ Result = SimplifyPHINode(cast(I), Q); break; case Instruction::Call: { + auto * VPInst = dyn_cast(I); + if (VPInst) { + Result = SimplifyVPIntrinsic(*VPInst, Q); + if (Result) break; + } + + CallSite CS((I)); Result = SimplifyCall(cast(I), Q); // Don't perform known bits simplification below for musttail calls.
if (cast(I)->isMustTailCall()) diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -848,6 +848,16 @@ return TTIImpl->getStoreVectorFactor(VF, StoreSize, ChainSizeInBytes, VecTy); } +bool TargetTransformInfo::shouldFoldVectorLengthIntoMask( + const PredicatedInstruction &PI) const { + return TTIImpl->shouldFoldVectorLengthIntoMask(PI); +} + +bool TargetTransformInfo::supportsVPOperation( + const PredicatedInstruction &PI) const { + return TTIImpl->supportsVPOperation(PI); +} + bool TargetTransformInfo::useReductionIntrinsic(unsigned Opcode, Type *Ty, ReductionFlags Flags) const { return TTIImpl->useReductionIntrinsic(Opcode, Ty, Flags); diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp --- a/llvm/lib/AsmParser/LLLexer.cpp +++ b/llvm/lib/AsmParser/LLLexer.cpp @@ -646,6 +646,7 @@ KEYWORD(inlinehint); KEYWORD(inreg); KEYWORD(jumptable); + KEYWORD(mask); KEYWORD(minsize); KEYWORD(naked); KEYWORD(nest); @@ -667,6 +668,7 @@ KEYWORD(optforfuzzing); KEYWORD(optnone); KEYWORD(optsize); + KEYWORD(passthru); KEYWORD(readnone); KEYWORD(readonly); KEYWORD(returned); @@ -690,6 +692,7 @@ KEYWORD(swiftself); KEYWORD(uwtable); KEYWORD(willreturn); + KEYWORD(vlen); KEYWORD(writeonly); KEYWORD(zeroext); KEYWORD(immarg); diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -1341,15 +1341,18 @@ case lltok::kw_dereferenceable: case lltok::kw_dereferenceable_or_null: case lltok::kw_inalloca: + case lltok::kw_mask: case lltok::kw_nest: case lltok::kw_noalias: case lltok::kw_nocapture: case lltok::kw_nonnull: + case lltok::kw_passthru: case lltok::kw_returned: case lltok::kw_sret: case lltok::kw_swifterror: case lltok::kw_swiftself: case lltok::kw_immarg: + case lltok::kw_vlen: HaveError |= Error(Lex.getLoc(), 
"invalid use of parameter-only attribute on a function"); @@ -1636,11 +1639,13 @@ } case lltok::kw_inalloca: B.addAttribute(Attribute::InAlloca); break; case lltok::kw_inreg: B.addAttribute(Attribute::InReg); break; + case lltok::kw_mask: B.addAttribute(Attribute::Mask); break; case lltok::kw_nest: B.addAttribute(Attribute::Nest); break; case lltok::kw_noalias: B.addAttribute(Attribute::NoAlias); break; case lltok::kw_nocapture: B.addAttribute(Attribute::NoCapture); break; case lltok::kw_nofree: B.addAttribute(Attribute::NoFree); break; case lltok::kw_nonnull: B.addAttribute(Attribute::NonNull); break; + case lltok::kw_passthru: B.addAttribute(Attribute::Passthru); break; case lltok::kw_readnone: B.addAttribute(Attribute::ReadNone); break; case lltok::kw_readonly: B.addAttribute(Attribute::ReadOnly); break; case lltok::kw_returned: B.addAttribute(Attribute::Returned); break; @@ -1648,6 +1653,7 @@ case lltok::kw_sret: B.addAttribute(Attribute::StructRet); break; case lltok::kw_swifterror: B.addAttribute(Attribute::SwiftError); break; case lltok::kw_swiftself: B.addAttribute(Attribute::SwiftSelf); break; + case lltok::kw_vlen: B.addAttribute(Attribute::VectorLength); break; case lltok::kw_writeonly: B.addAttribute(Attribute::WriteOnly); break; case lltok::kw_zeroext: B.addAttribute(Attribute::ZExt); break; case lltok::kw_immarg: B.addAttribute(Attribute::ImmArg); break; @@ -1740,13 +1746,16 @@ // Error handling. 
case lltok::kw_byval: case lltok::kw_inalloca: + case lltok::kw_mask: case lltok::kw_nest: case lltok::kw_nocapture: + case lltok::kw_passthru: case lltok::kw_returned: case lltok::kw_sret: case lltok::kw_swifterror: case lltok::kw_swiftself: case lltok::kw_immarg: + case lltok::kw_vlen: HaveError |= Error(Lex.getLoc(), "invalid use of parameter-only attribute"); break; diff --git a/llvm/lib/AsmParser/LLToken.h b/llvm/lib/AsmParser/LLToken.h --- a/llvm/lib/AsmParser/LLToken.h +++ b/llvm/lib/AsmParser/LLToken.h @@ -192,6 +192,7 @@ kw_inlinehint, kw_inreg, kw_jumptable, + kw_mask, kw_minsize, kw_naked, kw_nest, @@ -213,6 +214,7 @@ kw_optforfuzzing, kw_optnone, kw_optsize, + kw_passthru, kw_readnone, kw_readonly, kw_returned, @@ -233,6 +235,7 @@ kw_swiftself, kw_uwtable, kw_willreturn, + kw_vlen, kw_writeonly, kw_zeroext, kw_immarg, diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -1304,6 +1304,15 @@ case Attribute::SanitizeMemTag: llvm_unreachable("sanitize_memtag attribute not supported in raw format"); break; + case Attribute::Mask: + llvm_unreachable("mask attribute not supported in raw format"); + break; + case Attribute::VectorLength: + llvm_unreachable("vlen attribute not supported in raw format"); + break; + case Attribute::Passthru: + llvm_unreachable("passthru attribute not supported in raw format"); + break; } llvm_unreachable("Unsupported attribute type"); } @@ -1318,6 +1327,9 @@ I == Attribute::DereferenceableOrNull || I == Attribute::ArgMemOnly || I == Attribute::AllocSize || + I == Attribute::Mask || + I == Attribute::VectorLength || + I == Attribute::Passthru || I == Attribute::NoSync) continue; if (uint64_t A = (Val & getRawAttributeMask(I))) { @@ -1443,6 +1455,8 @@ return Attribute::InReg; case bitc::ATTR_KIND_JUMP_TABLE: return Attribute::JumpTable; + case bitc::ATTR_KIND_MASK: + return Attribute::Mask; 
case bitc::ATTR_KIND_MIN_SIZE: return Attribute::MinSize; case bitc::ATTR_KIND_NAKED: @@ -1491,6 +1505,8 @@ return Attribute::OptimizeForSize; case bitc::ATTR_KIND_OPTIMIZE_NONE: return Attribute::OptimizeNone; + case bitc::ATTR_KIND_PASSTHRU: + return Attribute::Passthru; case bitc::ATTR_KIND_READ_NONE: return Attribute::ReadNone; case bitc::ATTR_KIND_READ_ONLY: @@ -1537,6 +1553,8 @@ return Attribute::UWTable; case bitc::ATTR_KIND_WILLRETURN: return Attribute::WillReturn; + case bitc::ATTR_KIND_VECTORLENGTH: + return Attribute::VectorLength; case bitc::ATTR_KIND_WRITEONLY: return Attribute::WriteOnly; case bitc::ATTR_KIND_Z_EXT: diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -678,6 +678,12 @@ return bitc::ATTR_KIND_READ_ONLY; case Attribute::Returned: return bitc::ATTR_KIND_RETURNED; + case Attribute::Mask: + return bitc::ATTR_KIND_MASK; + case Attribute::VectorLength: + return bitc::ATTR_KIND_VECTORLENGTH; + case Attribute::Passthru: + return bitc::ATTR_KIND_PASSTHRU; case Attribute::ReturnsTwice: return bitc::ATTR_KIND_RETURNS_TWICE; case Attribute::SExt: diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -25,6 +25,7 @@ ExpandMemCmp.cpp ExpandPostRAPseudos.cpp ExpandReductions.cpp + ExpandVectorPredication.cpp FaultMaps.cpp FEntryInserter.cpp FinalizeISel.cpp diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp @@ -0,0 +1,688 @@ +//===--- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements IR expansion for vector predication intrinsics, allowing +// targets to enable vector predication until just before codegen. +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/ExpandVectorPredication.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PredicatedInst.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" + +using namespace llvm; + +#define DEBUG_TYPE "expand-vec-pred" + +STATISTIC(NumFoldedVL, "Number of folded vector length params"); +STATISTIC(numLoweredVPOps, "Number of folded vector predication operations"); + +/// \returns Whether the vector mask \p MaskVal is known to have all lane bits set (false if the dyn_cast fails). +static bool IsAllTrueMask(Value *MaskVal) { + auto ConstVec = dyn_cast(MaskVal); + if (!ConstVec) + return false; + return ConstVec->isAllOnesValue(); +} + +/// \returns The constant \p ConstVal broadcast to every lane of \p VecTy. +static Value *BroadcastConstant(Constant *ConstVal, VectorType *VecTy) { + return ConstantDataVector::getSplat(VecTy->getVectorNumElements(), ConstVal); +} + +/// \returns The neutral element of the reduction \p VPRedID.
+static Value *GetNeutralElementVector(Intrinsic::ID VPRedID, + VectorType *VecTy) { + unsigned ElemBits = VecTy->getScalarSizeInBits(); + + switch (VPRedID) { + default: + abort(); // invalid vp reduction intrinsic + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + case Intrinsic::vp_reduce_umax: + return Constant::getNullValue(VecTy); + + case Intrinsic::vp_reduce_mul: + return BroadcastConstant( + ConstantInt::get(VecTy->getElementType(), 1, false), VecTy); + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_umin: + return Constant::getAllOnesValue(VecTy); + + case Intrinsic::vp_reduce_smin: + return BroadcastConstant( + ConstantInt::get(VecTy->getContext(), + APInt::getSignedMaxValue(ElemBits)), + VecTy); + case Intrinsic::vp_reduce_smax: + return BroadcastConstant( + ConstantInt::get(VecTy->getContext(), + APInt::getSignedMinValue(ElemBits)), + VecTy); + + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + return BroadcastConstant(ConstantFP::getQNaN(VecTy->getElementType()), + VecTy); + case Intrinsic::vp_reduce_fadd: // identity is -0.0: (+0.0) + (-0.0) would yield +0.0 + return BroadcastConstant(ConstantFP::getNegativeZero(VecTy->getElementType()), + VecTy); + case Intrinsic::vp_reduce_fmul: + return BroadcastConstant(ConstantFP::get(VecTy->getElementType(), 1.0), + VecTy); + } +} + +namespace { + +/// \brief The logical vector element size of this operation (currently a fixed 64). +int32_t GetFunctionalVectorElementSize() { + return 64; // TODO infer from operation (eg + // VPIntrinsic::getVectorElementSize()) +} + +/// \returns A vector with ascending integer indices (<0, 1, ..., NumElems-1>).
+///
+/// \p Builder
+///    Used to obtain the integer lane type.
+/// \p ElemBits
+///    Bit width of each integer lane.
+/// \p NumElems
+///    Number of lanes in the resulting constant vector.
+Value *CreateStepVector(IRBuilder<> &Builder, int32_t ElemBits,
+                        int32_t NumElems) {
+  // TODO add caching
+  Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+  SmallVector<Constant *, 16> ConstElems;
+  for (int32_t Idx = 0; Idx < NumElems; ++Idx)
+    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));
+
+  return ConstantVector::get(ConstElems);
+}
+
+/// \returns A bitmask that is true where the lane position is less than
+///          \p VLParam.
+///
+/// \p Builder
+///    Used for instruction creation.
+/// \p VLParam
+///    The explicit vector length parameter to test against the lane
+///    positions.
+/// \p ElemBits
+///    Integer bitsize used for the generated ICmp instruction.
+/// \p NumElems
+///    Static vector length of the operation.
+Value *ConvertVLToMask(IRBuilder<> &Builder, Value *VLParam, int32_t ElemBits,
+                       int32_t NumElems) {
+  // TODO increase elem bits to shrink wrap VLParam where necessary (eg if
+  // operating on i8)
+  Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+  // The vector length is compared as an unsigned quantity below, so widen it
+  // with a zero extension rather than a sign extension.
+  Value *ExtVLParam = Builder.CreateZExt(VLParam, LaneTy);
+  Value *VLSplat = Builder.CreateVectorSplat(NumElems, ExtVLParam);
+
+  Value *IdxVec = CreateStepVector(Builder, ElemBits, NumElems);
+
+  // Lane i is active iff i < VL.
+  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
+}
+
+/// \returns A non-excepting divisor constant for this type.
+///
+/// The divisor is the value one (splatted for vector types). An all-ones
+/// integer divisor would be -1 for signed division, and dividing the minimum
+/// signed value by -1 overflows, so it must not be used here.
+Constant *GetSafeDivisor(Type *DivTy) {
+  if (DivTy->isIntOrIntVectorTy())
+    return ConstantInt::get(DivTy, 1, false);
+
+  // ConstantFP::get splats for vector types and also handles scalar types.
+  if (DivTy->isFPOrFPVectorTy())
+    return ConstantFP::get(DivTy, 1.0);
+
+  llvm_unreachable("Not a valid type for division");
+}
+
+/// Transfer operation properties from \p OldVPI to \p NewVal.
+///
+/// Currently only fast-math flags are transferred, and only when \p NewVal is
+/// an instruction that is a floating-point math operator and \p OldVPI is one
+/// as well. In every other case this is a no-op.
+void TransferDecorations(Value *NewVal, VPIntrinsic *OldVPI) {
+  auto *NewInst = dyn_cast<Instruction>(NewVal);
+  if (!NewInst || !isa<FPMathOperator>(NewVal))
+    return;
+
+  auto *OldFMOp = dyn_cast<FPMathOperator>(OldVPI);
+  if (!OldFMOp)
+    return;
+
+  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
+}
+
+/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
+/// \p OldOp gets erased and must not be used afterwards.
+void ReplaceOperation(Value *NewOp, VPIntrinsic *OldOp) {
+  TransferDecorations(NewOp, OldOp);
+  OldOp->replaceAllUsesWith(NewOp);
+  OldOp->eraseFromParent();
+}
+
+/// \brief Lower this vector-predicated operator into standard IR.
+///
+/// Only fneg is handled. The new instruction is inserted before \p VPI, takes
+/// over its name, and \p VPI is erased.
+void LowerVPUnaryOperator(VPIntrinsic *VPI) {
+  assert(VPI->canIgnoreVectorLengthParam());
+  auto OC = VPI->getFunctionalOpcode();
+  assert(OC == Instruction::FNeg);
+  (void)OC; // only inspected by the assertion above
+
+  Value *FirstOp = VPI->getOperand(0);
+  auto *I = cast<Instruction>(VPI);
+  IRBuilder<> Builder(I);
+  // I also serves as the source of the fast-math flags for the new fneg.
+  Value *NewFNeg = Builder.CreateFNegFMF(FirstOp, I, I->getName());
+  ReplaceOperation(NewFNeg, VPI);
+}
+
+/// \brief Lower this VP binary operator to a non-VP binary operator.
+///
+/// The mask is dropped for operations that cannot fault on masked-off lanes.
+/// For division and remainder, the divisor of every masked-off lane is
+/// replaced by a safe constant first. \p VPI is erased.
+void LowerVPBinaryOperator(VPIntrinsic *VPI) {
+  assert(VPI->canIgnoreVectorLengthParam());
+  assert(VPI->isBinaryOp());
+
+  auto *OldBinOp = cast<Instruction>(VPI);
+
+  Value *FirstOp = VPI->getOperand(0);
+  Value *SndOp = VPI->getOperand(1);
+
+  IRBuilder<> Builder(OldBinOp);
+  Value *Mask = VPI->getMaskParam();
+
+  // Blend in safe operands
+  if (!IsAllTrueMask(Mask)) {
+    switch (VPI->getFunctionalOpcode()) {
+    default:
+      // can safely ignore the predicate
+      break;
+
+    // Division operators need a safe divisor on masked-off lanes
+    case Instruction::FDiv:
+    case Instruction::FRem:
+    case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::URem:
+    case Instruction::SRem: {
+      // 2nd operand must not be zero
+      Constant *SafeDivisor = GetSafeDivisor(VPI->getType());
+      SndOp = Builder.CreateSelect(Mask, SndOp, SafeDivisor);
+      break;
+    }
+    }
+  }
+
+  Value *NewBinOp = Builder.CreateBinOp(
+      static_cast<Instruction::BinaryOps>(VPI->getFunctionalOpcode()), FirstOp,
+      SndOp, VPI->getName(), nullptr);
+
+  ReplaceOperation(NewBinOp, VPI);
+}
+
+/// \brief Lower this vector-predicated cast operator.
+void LowerVPCastOperator(VPIntrinsic *VPI) { + assert(VPI->canIgnoreVectorLengthParam()); + assert(!VPI->isConstrainedOp()); + auto OC = VPI->getFunctionalOpcode(); + IRBuilder<> Builder(cast(VPI)); + auto NewCast = + Builder.CreateCast(static_cast(OC), + VPI->getArgOperand(0), VPI->getType(), VPI->getName()); + + ReplaceOperation(NewCast, VPI); +} + +/// \brief Lower llvm.vp.compose.* into a select instruction +void LowerVPCompose(VPIntrinsic *VPI) { + auto ElemBits = GetFunctionalVectorElementSize(); + ElementCount ElemCount = VPI->getStaticVectorLength(); + assert(!ElemCount.Scalable && "TODO scalable type support"); + + IRBuilder<> Builder(cast(VPI)); + auto PivotMask = + ConvertVLToMask(Builder, VPI->getOperand(2), ElemBits, ElemCount.Min); + auto NewCompose = Builder.CreateSelect(PivotMask, VPI->getOperand(0), + VPI->getOperand(1), VPI->getName()); + + ReplaceOperation(NewCompose, VPI); +} + +/// \brief Lower this llvm.vp.fma intrinsic to a llvm.fma intrinsic. +void LowerToIntrinsic(VPIntrinsic *VPI) { + assert(VPI->canIgnoreVectorLengthParam()); + + auto I = cast(VPI); + auto M = I->getParent()->getModule(); + IRBuilder<> Builder(I); + Intrinsic::ID IID = VPI->getFunctionalIntrinsicID(); + assert(IID != Intrinsic::not_intrinsic && "cannot lower to non-VP intrinsic"); + assert(!VPI->isConstrainedOp() && + "TODO implement lowering to constrained fp"); + assert(!VPIntrinsic::IsVPIntrinsic(IID)); + + SmallVector IntrinTypeVec; + IntrinTypeVec.push_back(VPI->getType()); // TODO simplify + + // Implicitly assumes that the return type is sufficient for disambiguation. + Function *IntrinFunc = Intrinsic::getDeclaration(M, IID, IntrinTypeVec); + assert(IntrinFunc); + + LLVM_DEBUG(dbgs() << "Using " << *IntrinFunc << " to lower " + << VPI->getCalledFunction() << "\n"); + + // Construct argument vector. 
+ assert(!IntrinFunc->getFunctionType()->isVarArg()); + unsigned NumIntrinParams = IntrinFunc->getFunctionType()->getNumParams(); + SmallVector IntrinArgs; + for (unsigned i = 0; i < NumIntrinParams; ++i) { + IntrinArgs.push_back(VPI->getArgOperand(i)); + } + + auto NewIntrin = Builder.CreateCall(IntrinFunc, IntrinArgs, VPI->getName()); + + ReplaceOperation(NewIntrin, VPI); +} + +/// \brief Lower this llvm.vp.reduce.* intrinsic to a llvm.experimental.reduce.* +/// intrinsic. +void LowerVPReduction(VPIntrinsic *VPI) { + assert(VPI->canIgnoreVectorLengthParam()); + assert(VPI->isReductionOp()); + + auto &I = *cast(VPI); + IRBuilder<> Builder(&I); + auto M = Builder.GetInsertBlock()->getModule(); + assert(M && "No module to declare reduction intrinsic in!"); + + SmallVector Args; + + Value *RedVectorParam = VPI->getReductionVectorParam(); + Value *RedAccuParam = VPI->getReductionAccuParam(); + Value *MaskParam = VPI->getMaskParam(); + auto FunctionalID = VPI->getFunctionalIntrinsicID(); + + // Insert neutral element in masked-out positions + bool IsUnmasked = IsAllTrueMask(VPI->getMaskParam()); + if (!IsUnmasked) { + auto *NeutralVector = GetNeutralElementVector( + VPI->getIntrinsicID(), cast(RedVectorParam->getType())); + RedVectorParam = + Builder.CreateSelect(MaskParam, RedVectorParam, NeutralVector); + } + + auto VecTypeArg = RedVectorParam->getType(); + + Value *NewReduct; + switch (FunctionalID) { + default: { + auto RedIntrinFunc = Intrinsic::getDeclaration(M, FunctionalID, VecTypeArg); + NewReduct = Builder.CreateCall(RedIntrinFunc, RedVectorParam, I.getName()); + assert(!RedAccuParam && "accu dropped"); + } break; + + case Intrinsic::experimental_vector_reduce_v2_fadd: + case Intrinsic::experimental_vector_reduce_v2_fmul: { + auto TypeArg = RedAccuParam->getType(); + auto RedIntrinFunc = + Intrinsic::getDeclaration(M, FunctionalID, {TypeArg, VecTypeArg}); + NewReduct = Builder.CreateCall(RedIntrinFunc, + {RedAccuParam, RedVectorParam}, I.getName()); + } 
break; + } + + TransferDecorations(NewReduct, VPI); + I.replaceAllUsesWith(NewReduct); + I.eraseFromParent(); +} + +/// \brief Lower this llvm.vp.(load|store|gather|scatter) to a non-vp +/// instruction. +void LowerVPMemoryIntrinsic(VPIntrinsic *VPI) { + assert(VPI->canIgnoreVectorLengthParam()); + auto &I = cast(*VPI); + + auto MaskParam = VPI->getMaskParam(); + auto PtrParam = VPI->getMemoryPointerParam(); + auto DataParam = VPI->getMemoryDataParam(); + bool IsUnmasked = IsAllTrueMask(MaskParam); + + IRBuilder<> Builder(&I); + MaybeAlign AlignOpt = VPI->getPointerAlignment(); + + Value *NewMemoryInst = nullptr; + switch (VPI->getIntrinsicID()) { + default: + abort(); // not a VP memory intrinsic + + case Intrinsic::vp_store: { + if (IsUnmasked) { + StoreInst *NewStore = Builder.CreateStore(DataParam, PtrParam, false); + if (AlignOpt.hasValue()) + NewStore->setAlignment(AlignOpt.getValue()); + NewMemoryInst = NewStore; + } else { + NewMemoryInst = Builder.CreateMaskedStore( + DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam); + } + } break; + + case Intrinsic::vp_load: { + if (IsUnmasked) { + LoadInst *NewLoad = Builder.CreateLoad(PtrParam, false); + if (AlignOpt.hasValue()) + NewLoad->setAlignment(AlignOpt.getValue()); + NewMemoryInst = NewLoad; + } else { + NewMemoryInst = + Builder.CreateMaskedLoad(PtrParam, AlignOpt.valueOrOne(), MaskParam); + } + } break; + + case Intrinsic::vp_scatter: { + // if (IsUnmasked) { + // StoreInst *NewStore = Builder.CreateStore(DataParam, PtrParam, false); + // if (AlignOpt.hasValue()) NewStore->setAlignment(AlignOpt.getValue()); + // NewMemoryInst = NewStore; + // } else { + NewMemoryInst = Builder.CreateMaskedScatter(DataParam, PtrParam, + AlignOpt.valueOrOne(), MaskParam); + // } + } break; + + case Intrinsic::vp_gather: { + // if (IsUnmasked) { + // LoadInst *NewLoad = Builder.CreateLoad(I.getType(), PtrParam, false); + // if (AlignOpt.hasValue()) NewLoad->setAlignment(AlignOpt.getValue()); + // NewMemoryInst = NewLoad; 
+ // } else { + NewMemoryInst = Builder.CreateMaskedGather(PtrParam, AlignOpt.valueOrOne(), + MaskParam, nullptr, I.getName()); + // } + } break; + } + + assert(NewMemoryInst); + ReplaceOperation(NewMemoryInst, VPI); +} + +/// \brief Lower llvm.vp.select.* to a select instruction. +void LowerVPSelectInst(VPIntrinsic *VPI) { + auto I = cast(VPI); + + auto NewSelect = SelectInst::Create(VPI->getMaskParam(), VPI->getOperand(1), + VPI->getOperand(2), I->getName(), I, I); + ReplaceOperation(NewSelect, VPI); +} + +/// \brief Lower llvm.vp.(icmp|fcmp) to an icmp or fcmp instruction. +void LowerVPCompare(VPIntrinsic *VPI) { + auto NewCmp = CmpInst::Create( + static_cast(VPI->getFunctionalOpcode()), + VPI->getCmpPredicate(), VPI->getOperand(0), VPI->getOperand(1), + VPI->getName(), cast(VPI)); + ReplaceOperation(NewCmp, VPI); +} + +/// \brief Try to lower this vp_vshift operation. +bool TryLowerVShift(VPIntrinsic *VPI) { + // vshift(vec, amount, mask, vlen) + + // cannot lower dynamic shift amount + auto *SrcVal = VPI->getArgOperand(0); + auto *AmountVal = VPI->getArgOperand(1); + if (!isa(AmountVal)) + return false; + int64_t Amount = cast(AmountVal)->getSExtValue(); + + // cannot lower scalable vector size + auto ElemCount = VPI->getType()->getVectorElementCount(); + if (ElemCount.Scalable) + return false; + int VecWidth = ElemCount.Min; + + auto IntTy = Type::getInt32Ty(VPI->getContext()); + + // constitute shuffle mask. 
+ std::vector Elems; + for (int i = 0; i < (int)ElemCount.Min; ++i) { + int64_t SrcLane = i - Amount; + if (SrcLane < 0 || SrcLane >= VecWidth) + Elems.push_back(UndefValue::get(IntTy)); + else + Elems.push_back(ConstantInt::get(IntTy, SrcLane)); + } + auto *ShuffleMask = ConstantVector::get(Elems); + + auto *V2 = UndefValue::get(SrcVal->getType()); + + // Translate to a shuffle + auto NewI = new ShuffleVectorInst(SrcVal, V2, ShuffleMask, VPI->getName(), + cast(VPI)); + ReplaceOperation(NewI, VPI); + return true; +} + +/// \brief Lower a llvm.vp.* intrinsic that is not functionally equivalent to a +/// standard IR instruction. +void LowerUnmatchedVPIntrinsic(VPIntrinsic *VPI) { + if (VPI->isReductionOp()) + return LowerVPReduction(VPI); + + switch (VPI->getIntrinsicID()) { + default: + LowerToIntrinsic(VPI); + break; + + // Shuffles + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + if (TryLowerVShift(VPI)) + return; + + LLVM_DEBUG(dbgs() << "Silently keeping VP intrinsic: can not substitute: " + << *VPI << "\n"); + return; + + case Intrinsic::vp_compose: + LowerVPCompose(VPI); + break; + + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + LowerVPMemoryIntrinsic(VPI); + break; + } +} + +/// \brief Expand llvm.vp.* intrinsics as requested by \p TTI. +bool expandVectorPredication(Function &F, const TargetTransformInfo *TTI) { + bool Changed = false; + + // Holds all vector-predicated ops with an effective vector length param that + // needs to be folded into the mask param. + SmallVector ExpandVLWorklist; + + // Holds all vector-predicated ops that need to translated into non-VP ops. 
+ SmallVector ExpandOpWorklist; + + for (auto &I : instructions(F)) { + auto *VPI = dyn_cast(&I); + if (!VPI) + continue; + + auto &PI = cast(*VPI); + + bool supportsVPOp = TTI->supportsVPOperation(PI); + bool hasEffectiveVLParam = !VPI->canIgnoreVectorLengthParam(); + bool shouldFoldVLParam = + !supportsVPOp || TTI->shouldFoldVectorLengthIntoMask(PI); + + LLVM_DEBUG(dbgs() << "Inspecting " << *VPI + << "\n:: target-support=" << supportsVPOp + << ", effectiveVecLen=" << hasEffectiveVLParam + << ", shouldFoldVecLen=" << shouldFoldVLParam << "\n"); + + if (shouldFoldVLParam) { + if (hasEffectiveVLParam && VPI->getMaskParam()) { + ExpandVLWorklist.push_back(VPI); + } else { + ExpandOpWorklist.push_back(VPI); + } + } + } + + // Fold vector-length params into the mask param. + LLVM_DEBUG(dbgs() << "\n:::: Folding vlen into mask. ::::\n"); + for (VPIntrinsic *VPI : ExpandVLWorklist) { + ++NumFoldedVL; + Changed = true; + + LLVM_DEBUG(dbgs() << "Folding vlen for op: " << *VPI << '\n'); + + IRBuilder<> Builder(cast(VPI)); + + Value *OldMaskParam = VPI->getMaskParam(); + Value *OldVLParam = VPI->getVectorLengthParam(); + assert(OldMaskParam && "no mask param to fold the vl param into"); + assert(OldVLParam && "no vector length param to fold away"); + + LLVM_DEBUG(dbgs() << "OLD vlen: " << *OldVLParam << '\n'); + LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n'); + + // Determine the lane bit size that should be used to lower this op + auto ElemBits = GetFunctionalVectorElementSize(); + ElementCount ElemCount = VPI->getStaticVectorLength(); + assert(!ElemCount.Scalable && "TODO scalable vector support"); + + // Lower VL to M + auto *VLMask = + ConvertVLToMask(Builder, OldVLParam, ElemBits, ElemCount.Min); + auto NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam); + VPI->setMaskParam( + NewMaskParam); // FIXME cannot trivially use the PI abstraction here. 
+ + // Disable VL + auto FullVL = Builder.getInt32(ElemCount.Min); + VPI->setVectorLengthParam(FullVL); + assert(VPI->canIgnoreVectorLengthParam() && + "transformation did not render the vl param ineffective!"); + + LLVM_DEBUG(dbgs() << "NEW vlen: " << *FullVL << '\n'); + LLVM_DEBUG(dbgs() << "NEW mask: " << *NewMaskParam << '\n'); + + auto &PI = cast(*VPI); + if (!TTI->supportsVPOperation(PI)) { + ExpandOpWorklist.push_back(VPI); + } + } + + // Translate into non-VP ops + LLVM_DEBUG(dbgs() << "\n:::: Lowering VP into non-VP ops ::::\n"); + for (VPIntrinsic *VPI : ExpandOpWorklist) { + ++numLoweredVPOps; + Changed = true; + + LLVM_DEBUG(dbgs() << "Lowering vp op: " << *VPI << '\n'); + + // Try lowering to a LLVM instruction first. + unsigned OC = VPI->getFunctionalOpcode(); +#define FIRST_UNARY_INST(X) unsigned FirstUnOp = X; +#define LAST_UNARY_INST(X) unsigned LastUnOp = X; +#define FIRST_BINARY_INST(X) unsigned FirstBinOp = X; +#define LAST_BINARY_INST(X) unsigned LastBinOp = X; +#define FIRST_CAST_INST(X) unsigned FirstCastOp = X; +#define LAST_CAST_INST(X) unsigned LastCastOp = X; +#include "llvm/IR/Instruction.def" + + if (FirstBinOp <= OC && OC <= LastBinOp) { + LowerVPBinaryOperator(VPI); + continue; + } + if (FirstUnOp <= OC && OC <= LastUnOp) { + LowerVPUnaryOperator(VPI); + continue; + } + if (FirstCastOp <= OC && OC <= LastCastOp) { + LowerVPCastOperator(VPI); + continue; + } + + // Lower to a non-VP intrinsic. 
+ switch (OC) { + default: + abort(); // unexpected intrinsic + + case Instruction::Call: + LowerUnmatchedVPIntrinsic(VPI); + break; + + case Instruction::Select: + LowerVPSelectInst(VPI); + break; + + case Instruction::Store: + case Instruction::Load: + LowerVPMemoryIntrinsic(VPI); + break; + + case Instruction::ICmp: + case Instruction::FCmp: + LowerVPCompare(VPI); + break; + } + } + + return Changed; +} + +class ExpandVectorPredication : public FunctionPass { +public: + static char ID; + ExpandVectorPredication() : FunctionPass(ID) { + initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + const auto *TTI = &getAnalysis().getTTI(F); + return expandVectorPredication(F, TTI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.setPreservesCFG(); + } +}; +} // namespace + +char ExpandVectorPredication::ID; +INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expand-vec-pred", + "Expand vector predication intrinsics", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(ExpandVectorPredication, "expand-vec-pred", + "Expand vector predication intrinsics", false, false) + +FunctionPass *llvm::createExpandVectorPredicationPass() { + return new ExpandVectorPredication(); +} + +PreservedAnalyses +ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) { + const auto &TTI = AM.getResult(F); + if (!expandVectorPredication(F, &TTI)) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet(); + return PA; +} diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -438,6 +438,7 @@ SDValue visitBITCAST(SDNode *N); SDValue visitBUILD_PAIR(SDNode *N); SDValue visitFADD(SDNode *N); + 
SDValue visitFADD_VP(SDNode *N); SDValue visitFSUB(SDNode *N); SDValue visitFMUL(SDNode *N); SDValue visitFMA(SDNode *N); @@ -486,6 +487,7 @@ SDValue visitFP16_TO_FP(SDNode *N); SDValue visitVECREDUCE(SDNode *N); + template SDValue visitFADDForFMACombine(SDNode *N); SDValue visitFSUBForFMACombine(SDNode *N); SDValue visitFMULForFMADistributiveCombine(SDNode *N); @@ -748,6 +750,137 @@ void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); } }; +struct EmptyMatchContext { + SelectionDAG & DAG; + + EmptyMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + {} + + bool match(SDValue OpN, unsigned OpCode) const { return OpCode == OpN->getOpcode(); } + + unsigned getFunctionOpCode(SDValue N) const { + return N->getOpcode(); + } + + bool isCompatible(SDValue OpVal) const { return true; } + + // Specialize based on number of operands. + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, Operand, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, N1, N2, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, Flags); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4, SDValue N5) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4, N5); + } +}; + +struct +VPMatchContext { + SelectionDAG & DAG; + SDNode * Root; + SDValue RootMaskOp; + SDValue 
RootVectorLenOp; + + VPMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + , Root(Root) + , RootMaskOp() + , RootVectorLenOp() + { + if (Root->isVP()) { + int RootMaskPos = ISD::GetMaskPosVP(Root->getOpcode()); + if (RootMaskPos != -1) { + RootMaskOp = Root->getOperand(RootMaskPos); + } + + int RootVLenPos = ISD::GetVectorLengthPosVP(Root->getOpcode()); + if (RootVLenPos != -1) { + RootVectorLenOp = Root->getOperand(RootVLenPos); + } + } + } + + unsigned getFunctionOpCode(SDValue N) const { + unsigned VPOpCode = N->getOpcode(); + return ISD::GetFunctionOpCodeForVP(VPOpCode, !N->getFlags().hasNoFPExcept()); + } + + bool isCompatible(SDValue OpVal) const { + if (!OpVal->isVP()) { + return !Root->isVP(); + + } else { + unsigned VPOpCode = OpVal->getOpcode(); + int MaskPos = ISD::GetMaskPosVP(VPOpCode); + if (MaskPos != -1 && RootMaskOp != OpVal.getOperand(MaskPos)) { + return false; + } + + int VLenPos = ISD::GetVectorLengthPosVP(VPOpCode); + if (VLenPos != -1 && RootVectorLenOp != OpVal.getOperand(VLenPos)) { + return false; + } + + return true; + } + } + + /// whether \p OpN is a node that is functionally compatible with the NodeType \p OpNodeTy + bool match(SDValue OpVal, unsigned OpNT) const { + return isCompatible(OpVal) && getFunctionOpCode(OpVal) == OpNT; + } + + // Specialize based on number of operands. 
+ // TODO emit VP intrinsics where MaskOp/VectorLenOp != null + // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosVP(VPOpcode); + int VLenPos = ISD::GetVectorLengthPosVP(VPOpcode); + assert(MaskPos == 1 && VLenPos == 2); + + return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosVP(VPOpcode); + int VLenPos = ISD::GetVectorLengthPosVP(VPOpcode); + assert(MaskPos == 2 && VLenPos == 3); + + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosVP(VPOpcode); + int VLenPos = ISD::GetVectorLengthPosVP(VPOpcode); + assert(MaskPos == 3 && VLenPos == 4); + + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); + } +}; + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -1583,6 +1716,7 @@ case ISD::BITCAST: return visitBITCAST(N); case ISD::BUILD_PAIR: return visitBUILD_PAIR(N); case ISD::FADD: return visitFADD(N); + case ISD::VP_FADD: return visitFADD_VP(N); case ISD::FSUB: return visitFSUB(N); case ISD::FMUL: return visitFMUL(N); case ISD::FMA: return visitFMA(N); @@ -11652,13 +11786,18 @@ return F.hasAllowContract() || F.hasAllowReassociation(); } + /// Try to perform FMA combining on a given FADD node. 
+template SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); + MatchContextClass matcher(DAG, N); + if (!matcher.isCompatible(N0) || !matcher.isCompatible(N1)) return SDValue(); + const TargetOptions &Options = DAG.getTarget().Options; // Floating-point multiply-add with intermediate rounding. @@ -11691,8 +11830,8 @@ // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. - auto isContractableFMUL = [AllowFusionGlobally](SDValue N) { - if (N.getOpcode() != ISD::FMUL) + auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) { + if (!matcher.match(N, ISD::FMUL)) return false; return AllowFusionGlobally || isContractable(N.getNode()); }; @@ -11705,44 +11844,44 @@ // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), N1, Flags); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), N0, Flags); } // Look through FP_EXTEND nodes to do more combining. 
// fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) - if (N0.getOpcode() == ISD::FP_EXTEND) { + if ((N0.getOpcode() == ISD::FP_EXTEND) && matcher.isCompatible(N0.getOperand(0))) { SDValue N00 = N0.getOperand(0); if (isContractableFMUL(N00) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1, Flags); } } // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N1, ISD::FP_EXTEND)) { SDValue N10 = N1.getOperand(0); if (isContractableFMUL(N10) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N10.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0, Flags); } } @@ -11751,12 +11890,12 @@ if (Aggressive) { // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) if (CanFuse && - N0.getOpcode() == PreferredFusedOpcode && - N0.getOperand(2).getOpcode() == ISD::FMUL && + matcher.match(N0, PreferredFusedOpcode) && + matcher.match(N0.getOperand(2), ISD::FMUL) && N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), N1, Flags), Flags); @@ -11764,12 +11903,12 @@ // fold (fadd x, (fma y, z, 
(fmul u, v)) -> (fma y, z (fma u, v, x)) if (CanFuse && - N1->getOpcode() == PreferredFusedOpcode && - N1.getOperand(2).getOpcode() == ISD::FMUL && + matcher.match(N1, PreferredFusedOpcode) && + matcher.match(N1.getOperand(2), ISD::FMUL) && N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(2).getOperand(0), N1.getOperand(2).getOperand(1), N0, Flags), Flags); @@ -11781,15 +11920,15 @@ auto FoldFAddFMAFPExtFMul = [&] ( SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + return matcher.getNode(PreferredFusedOpcode, SL, VT, X, Y, + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z, Flags), Flags); }; - if (N0.getOpcode() == PreferredFusedOpcode) { + if (matcher.match(N0, PreferredFusedOpcode)) { SDValue N02 = N0.getOperand(2); - if (N02.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N02, ISD::FP_EXTEND)) { SDValue N020 = N02.getOperand(0); if (isContractableFMUL(N020) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, @@ -11809,12 +11948,12 @@ auto FoldFAddFPExtFMAFMul = [&] ( SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, X), - DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, X), + 
matcher.getNode(ISD::FP_EXTEND, SL, VT, Y), + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z, Flags), Flags); }; if (N0.getOpcode() == ISD::FP_EXTEND) { @@ -12278,6 +12417,15 @@ return SDValue(); } +SDValue DAGCombiner::visitFADD_VP(SDNode *N) { + // FADD -> FMA combines: + if (SDValue Fused = visitFADDForFMACombine(N)) { + AddToWorklist(Fused.getNode()); + return Fused; + } + return SDValue(); +} + SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -12455,7 +12603,7 @@ } // enable-unsafe-fp-math // FADD -> FMA combines: - if (SDValue Fused = visitFADDForFMACombine(N)) { + if (SDValue Fused = visitFADDForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; } diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -126,6 +126,13 @@ case ISD::FLT_ROUNDS_: Res = PromoteIntRes_FLT_ROUNDS(N); break; + case ISD::VP_REDUCE_MUL: + case ISD::VP_REDUCE_ADD: + case ISD::VP_REDUCE_AND: + case ISD::VP_REDUCE_XOR: + case ISD::VP_REDUCE_OR: + Res = PromoteIntRes_VP_REDUCE_nostart(N); break; + case ISD::AND: case ISD::OR: case ISD::XOR: @@ -1154,7 +1161,7 @@ } // Handle promotion for the ADDE/SUBE/ADDCARRY/SUBCARRY nodes. Notice that -// the third operand of ADDE/SUBE nodes is carry flag, which differs from +// the third operand of ADDE/SUBE nodes is carry flag, which differs from // the ADDCARRY/SUBCARRY nodes in that the third operand is carry Boolean. 
SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBCARRY(SDNode *N, unsigned ResNo) { if (ResNo == 1) @@ -1312,6 +1319,9 @@ return false; } + if (N->isVP()) { + Res = PromoteIntOp_VP(N, OpNo); + } else { switch (N->getOpcode()) { default: #ifndef NDEBUG @@ -1396,6 +1406,7 @@ case ISD::VECREDUCE_UMAX: case ISD::VECREDUCE_UMIN: Res = PromoteIntOp_VECREDUCE(N); break; } + } // If the result is null, the sub-method took care of registering results etc. if (!Res.getNode()) return false; @@ -1689,6 +1700,25 @@ TruncateStore, N->isCompressingStore()); } +SDValue DAGTypeLegalizer::PromoteIntOp_VP(SDNode *N, unsigned OpNo) { + EVT DataVT; + switch (N->getOpcode()) { + default: + DataVT = N->getValueType(0); + break; + + case ISD::VP_STORE: + case ISD::VP_SCATTER: + llvm_unreachable("TODO implement VP memory nodes"); + } + + // TODO assert that \p OpNo is the mask + SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT); + SmallVector NewOps(N->op_begin(), N->op_end()); + NewOps[OpNo] = Mask; + return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0); +} + SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo) { assert(OpNo == 3 && "Only know how to promote the mask!"); @@ -4384,6 +4414,31 @@ return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, NOutVT, Op); } +SDValue DAGTypeLegalizer::PromoteIntRes_VP_REDUCE_nostart(SDNode *N) { + SDLoc dl(N); + + SDValue VecVal = N->getOperand(0); + SDValue MaskVal = N->getOperand(1); + SDValue LenVal = N->getOperand(2); + + EVT VecVT = VecVal.getValueType(); + + assert(VecVal.getValueType().isVector() && "Input must be a vector"); + assert(MaskVal.getValueType().isVector() && "Mask must be a vector"); + assert(!LenVal.getValueType().isVector() && "Vector length must be a scalar"); + + EVT OutVT = N->getValueType(0); + EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT); + assert(NOutVT.isScalarInteger() && "Type must be promoted to a scalar integer type"); + EVT NVecVT = 
TLI.getTypeToTransformTo(*DAG.getContext(), VecVT); + // EVT NVecVT = EVT::getVectorVT(*DAG.getContext(), NOutVT, VecVT.getVectorNumElements(), VecVT.isScalableVector()); + + // extend operand along with result type + SDValue ExtVecVal = (NVecVT == VecVT) ? VecVal : DAG.getNode(ISD::ANY_EXTEND, dl, NVecVT, VecVal); + + return DAG.getNode(N->getOpcode(), dl, NOutVT, {ExtVecVal, MaskVal, LenVal}); +} + SDValue DAGTypeLegalizer::PromoteIntRes_SPLAT_VECTOR(SDNode *N) { SDLoc dl(N); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -340,6 +340,9 @@ SDValue PromoteIntRes_VECREDUCE(SDNode *N); SDValue PromoteIntRes_ABS(SDNode *N); + // vp reduction without start value + SDValue PromoteIntRes_VP_REDUCE_nostart(SDNode *N); + // Integer Operand Promotion. bool PromoteIntegerOperand(SDNode *N, unsigned OpNo); SDValue PromoteIntOp_ANY_EXTEND(SDNode *N); @@ -377,6 +380,7 @@ SDValue PromoteIntOp_FIX(SDNode *N); SDValue PromoteIntOp_FPOWI(SDNode *N); SDValue PromoteIntOp_VECREDUCE(SDNode *N); + SDValue PromoteIntOp_VP(SDNode *N, unsigned OpNo); void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -443,6 +443,218 @@ return Result; } +//===----------------------------------------------------------------------===// +// SDNode VP Support +//===----------------------------------------------------------------------===// + +/// Return the operand index of the mask of VP node \p OpCode, or -1 if the opcode is not listed. +int +ISD::GetMaskPosVP(unsigned OpCode) { + switch (OpCode) { + default: return -1; + + case ISD::VP_FNEG: + return 1; + + case ISD::VP_ADD: + case ISD::VP_SUB: + case ISD::VP_MUL: + case ISD::VP_SDIV: + case ISD::VP_SREM: + case ISD::VP_UDIV: + case ISD::VP_UREM: + +
case ISD::VP_AND: + case ISD::VP_OR: + case ISD::VP_XOR: + case ISD::VP_SHL: + case ISD::VP_SRA: + case ISD::VP_SRL: + + case ISD::VP_FADD: + case ISD::VP_FMUL: + case ISD::VP_FSUB: + case ISD::VP_FDIV: + case ISD::VP_FREM: + return 2; + + case ISD::VP_FMA: + case ISD::VP_SELECT: + return 3; + + case VP_REDUCE_ADD: + case VP_REDUCE_MUL: + case VP_REDUCE_AND: + case VP_REDUCE_OR: + case VP_REDUCE_XOR: + case VP_REDUCE_SMAX: + case VP_REDUCE_SMIN: + case VP_REDUCE_UMAX: + case VP_REDUCE_UMIN: + case VP_REDUCE_FMAX: + case VP_REDUCE_FMIN: + return 1; + + case VP_REDUCE_FADD: + case VP_REDUCE_FMUL: + return 2; + + /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. + // (implicit) case ISD::VP_COMPOSE: return -1 + } +} + +/// Return the operand index of the explicit vector length of VP node \p OpCode, or -1 if the opcode is not listed. +int +ISD::GetVectorLengthPosVP(unsigned OpCode) { + switch (OpCode) { + default: return -1; + + case VP_SELECT: + return 0; + + case VP_FNEG: + return 2; + + case VP_ADD: + case VP_SUB: + case VP_MUL: + case VP_SDIV: + case VP_SREM: + case VP_UDIV: + case VP_UREM: + + case VP_AND: + case VP_OR: + case VP_XOR: + case VP_SHL: + case VP_SRA: + case VP_SRL: + + case VP_FADD: + case VP_FMUL: + case VP_FSUB: + case VP_FDIV: + case VP_FREM: + return 3; + + case VP_FMA: + return 4; + + case VP_COMPOSE: + return 3; + + case VP_REDUCE_ADD: + case VP_REDUCE_MUL: + case VP_REDUCE_AND: + case VP_REDUCE_OR: + case VP_REDUCE_XOR: + case VP_REDUCE_SMAX: + case VP_REDUCE_SMIN: + case VP_REDUCE_UMAX: + case VP_REDUCE_UMIN: + case VP_REDUCE_FMAX: + case VP_REDUCE_FMIN: + return 2; + + case VP_REDUCE_FADD: + case VP_REDUCE_FMUL: + return 3; + + } +} + +/// Map VP node \p OpCode to its non-VP opcode (STRICT_ variant of FP arithmetic when \p hasFPExcept); unlisted opcodes are returned unchanged. +unsigned +ISD::GetFunctionOpCodeForVP(unsigned OpCode, bool hasFPExcept) { + switch (OpCode) { + default: return OpCode; + + case VP_SELECT: return ISD::VSELECT; + case VP_ADD: return ISD::ADD; + case VP_SUB: return ISD::SUB; + case VP_MUL: return ISD::MUL; + case VP_SDIV: return ISD::SDIV; + case VP_SREM: return ISD::SREM; + case VP_UDIV: return ISD::UDIV; + case VP_UREM: return ISD::UREM; + + case VP_AND: return ISD::AND; + case
VP_OR: return ISD::OR; + case VP_XOR: return ISD::XOR; + case VP_SHL: return ISD::SHL; + case VP_SRA: return ISD::SRA; + case VP_SRL: return ISD::SRL; + + case VP_FNEG: return ISD::FNEG; + case VP_FADD: return hasFPExcept ? ISD::STRICT_FADD : ISD::FADD; + case VP_FSUB: return hasFPExcept ? ISD::STRICT_FSUB : ISD::FSUB; + case VP_FMUL: return hasFPExcept ? ISD::STRICT_FMUL : ISD::FMUL; + case VP_FDIV: return hasFPExcept ? ISD::STRICT_FDIV : ISD::FDIV; + case VP_FREM: return hasFPExcept ? ISD::STRICT_FREM : ISD::FREM; + + case VP_REDUCE_AND: return VECREDUCE_AND; + case VP_REDUCE_OR: return VECREDUCE_OR; + case VP_REDUCE_XOR: return VECREDUCE_XOR; + case VP_REDUCE_ADD: return VECREDUCE_ADD; + case VP_REDUCE_MUL: return VECREDUCE_MUL; + case VP_REDUCE_FADD: return VECREDUCE_FADD; + case VP_REDUCE_FMUL: return VECREDUCE_FMUL; + case VP_REDUCE_FMAX: return VECREDUCE_FMAX; + case VP_REDUCE_FMIN: return VECREDUCE_FMIN; + case VP_REDUCE_UMAX: return VECREDUCE_UMAX; + case VP_REDUCE_UMIN: return VECREDUCE_UMIN; + case VP_REDUCE_SMAX: return VECREDUCE_SMAX; + case VP_REDUCE_SMIN: return VECREDUCE_SMIN; + + case VP_STORE: return ISD::MSTORE; + case VP_LOAD: return ISD::MLOAD; + case VP_GATHER: return ISD::MGATHER; + case VP_SCATTER: return ISD::MSCATTER; + + case VP_FMA: return hasFPExcept ?
ISD::STRICT_FMA : ISD::FMA; + } +} + +unsigned +ISD::GetVPForFunctionOpCode(unsigned OpCode) { + switch (OpCode) { + default: llvm_unreachable("can not translate this Opcode to VP"); + + case VSELECT: return ISD::VP_SELECT; + case ADD: return ISD::VP_ADD; + case SUB: return ISD::VP_SUB; + case MUL: return ISD::VP_MUL; + case SDIV: return ISD::VP_SDIV; + case SREM: return ISD::VP_SREM; + case UDIV: return ISD::VP_UDIV; + case UREM: return ISD::VP_UREM; + + case AND: return ISD::VP_AND; + case OR: return ISD::VP_OR; + case XOR: return ISD::VP_XOR; + case SHL: return ISD::VP_SHL; + case SRA: return ISD::VP_SRA; + case SRL: return ISD::VP_SRL; + + case FNEG: return ISD::VP_FNEG; + case STRICT_FADD: + case FADD: return ISD::VP_FADD; + case STRICT_FSUB: + case FSUB: return ISD::VP_FSUB; + case STRICT_FMUL: + case FMUL: return ISD::VP_FMUL; + case STRICT_FDIV: + case FDIV: return ISD::VP_FDIV; + case STRICT_FREM: + case FREM: return ISD::VP_FREM; + + case STRICT_FMA: + case FMA: return ISD::VP_FMA; + } +} + + //===----------------------------------------------------------------------===// // SDNode Profile Support //===----------------------------------------------------------------------===// @@ -573,6 +780,34 @@ ID.AddInteger(ST->getPointerInfo().getAddrSpace()); break; } + case ISD::VP_LOAD: { + const VPLoadSDNode *ELD = cast(N); + ID.AddInteger(ELD->getMemoryVT().getRawBits()); + ID.AddInteger(ELD->getRawSubclassData()); + ID.AddInteger(ELD->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_STORE: { + const VPStoreSDNode *EST = cast(N); + ID.AddInteger(EST->getMemoryVT().getRawBits()); + ID.AddInteger(EST->getRawSubclassData()); + ID.AddInteger(EST->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_GATHER: { + const VPGatherSDNode *EG = cast(N); + ID.AddInteger(EG->getMemoryVT().getRawBits()); + ID.AddInteger(EG->getRawSubclassData()); + ID.AddInteger(EG->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_SCATTER: { + const 
VPScatterSDNode *ES = cast(N); + ID.AddInteger(ES->getMemoryVT().getRawBits()); + ID.AddInteger(ES->getRawSubclassData()); + ID.AddInteger(ES->getPointerInfo().getAddrSpace()); + break; + } case ISD::MLOAD: { const MaskedLoadSDNode *MLD = cast(N); ID.AddInteger(MLD->getMemoryVT().getRawBits()); @@ -7121,6 +7356,143 @@ LD->getExtensionType(), LD->isExpandingLoad()); } + +SDValue SelectionDAG::getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, + SDValue Ptr, SDValue Mask, SDValue VLen, + EVT MemVT, MachineMemOperand *MMO, + ISD::LoadExtType ExtTy) { + SDVTList VTs = getVTList(VT, MVT::Other); + SDValue Ops[] = { Chain, Ptr, Mask, VLen }; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_LOAD, VTs, Ops); + ID.AddInteger(MemVT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, ExtTy, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), VTs, + ExtTy, MemVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +/// Return the VP store node for operands {Chain, Val, Ptr, Mask, VLen}, reusing a matching CSE'd node when one exists. +SDValue SelectionDAG::getStoreVP(SDValue Chain, const SDLoc &dl, + SDValue Val, SDValue Ptr, SDValue Mask, + SDValue VLen, EVT MemVT, MachineMemOperand *MMO, + bool IsTruncating) { + assert(Chain.getValueType() == MVT::Other && + "Invalid chain type"); + SDVTList VTs = getVTList(MVT::Other); + SDValue Ops[] = { Chain, Val, Ptr, Mask, VLen }; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_STORE, VTs, Ops); + ID.AddInteger(MemVT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, IsTruncating, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return
SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), VTs, + IsTruncating, MemVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, + MachineMemOperand *MMO, + ISD::MemIndexType IndexType) { + assert(Ops.size() == 6 && "Incompatible number of operands"); + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_GATHER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, VT, MMO, IndexType)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), + VTs, VT, MMO, IndexType); + createOperands(N, Ops); + + assert(N->getMask().getValueType().getVectorNumElements() == + N->getValueType(0).getVectorNumElements() && + "Vector width mismatch between mask and data"); + assert(N->getIndex().getValueType().getVectorNumElements() >= + N->getValueType(0).getVectorNumElements() && + "Vector width mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, + MachineMemOperand *MMO, + ISD::MemIndexType IndexType) { + assert(Ops.size() == 7 && "Incompatible number of operands"); + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_SCATTER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, VT, MMO, IndexType)); + 
ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), + VTs, VT, MMO, IndexType); + createOperands(N, Ops); + + assert(N->getMask().getValueType().getVectorNumElements() == + N->getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between mask and data"); + assert(N->getIndex().getValueType().getVectorNumElements() >= + N->getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Base, SDValue Offset, SDValue Mask, EVT MemVT, diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h @@ -768,6 +768,12 @@ void visitIntrinsicCall(const CallInst &I, unsigned Intrinsic); void visitTargetIntrinsic(const CallInst &I, unsigned Intrinsic); void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI); + void visitVectorPredicationIntrinsic(const VPIntrinsic &VPI); + void visitCmpVP(const VPIntrinsic &I); + void visitLoadVP(const CallInst &I); + void visitStoreVP(const CallInst &I); + void visitGatherVP(const CallInst &I); + void visitScatterVP(const CallInst &I); void visitVAStart(const CallInst &I); void visitVAArg(const VAArgInst &I); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- 
a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4362,6 +4362,46 @@ setValue(&I, StoreNode); } +void SelectionDAGBuilder::visitStoreVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + auto getVPStoreOps = [&](Value* &Ptr, Value* &Mask, Value* &Src0, + Value * &VLen, unsigned & Alignment) { + // llvm.vp.store.*(Src0, Ptr, Mask, VLen); alignment comes from the pointer argument's parameter alignment + Src0 = I.getArgOperand(0); + Ptr = I.getArgOperand(1); + Alignment = I.getParamAlignment(1); + Mask = I.getArgOperand(2); + VLen = I.getArgOperand(3); + }; + + Value *PtrOperand, *MaskOperand, *Src0Operand, *VLenOperand; + unsigned Alignment = 0; + getVPStoreOps(PtrOperand, MaskOperand, Src0Operand, VLenOperand, Alignment); + + SDValue Ptr = getValue(PtrOperand); + SDValue Src0 = getValue(Src0Operand); + SDValue Mask = getValue(MaskOperand); + SDValue VLen = getValue(VLenOperand); + + EVT VT = Src0.getValueType(); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(PtrOperand), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + SDValue StoreNode = DAG.getStoreVP(getRoot(), sdl, Src0, Ptr, Mask, VLen, VT, + MMO, false /* Truncating */); + DAG.setRoot(StoreNode); + setValue(&I, StoreNode); +} + // Get a uniform base for the Gather/Scatter intrinsic. // The first argument of the Gather/Scatter intrinsic is a vector of pointers. // We try to represent it as a base pointer + vector of indices.
@@ -4629,6 +4669,162 @@ setValue(&I, Gather); } +void SelectionDAGBuilder::visitGatherVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + // llvm.vp.gather.*(Ptrs, Mask, VLen) + const Value *Ptr = I.getArgOperand(0); + SDValue Mask = getValue(I.getArgOperand(1)); + SDValue VLen = getValue(I.getArgOperand(2)); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); + unsigned Alignment = I.getParamAlignment(0); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + SDValue Root = DAG.getRoot(); + SDValue Base; + SDValue Index; + ISD::MemIndexType IndexType; + SDValue Scale; + const Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, IndexType, Scale, this); + bool ConstantMemory = false; + if (UniformBase && AA && + AA->pointsToConstantMemory( + MemoryLocation(BasePtr, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo))) { + // Do not serialize (non-volatile) loads of constant memory with anything. + Root = DAG.getEntryNode(); + ConstantMemory = true; + } + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(UniformBase ?
BasePtr : nullptr), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + if (!UniformBase) { + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Index = getValue(Ptr); + IndexType = ISD::SIGNED_SCALED; + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } + SDValue Ops[] = { Root, Base, Index, Scale, Mask, VLen }; + SDValue Gather = DAG.getGatherVP(DAG.getVTList(VT, MVT::Other), VT, sdl, Ops, MMO, IndexType); + + SDValue OutChain = Gather.getValue(1); + if (!ConstantMemory) + PendingLoads.push_back(OutChain); + setValue(&I, Gather); +} + +void SelectionDAGBuilder::visitScatterVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + // llvm.vp.scatter.*(Src0, Ptrs, Mask, VLen) + const Value *Ptr = I.getArgOperand(1); + SDValue Src0 = getValue(I.getArgOperand(0)); + SDValue Mask = getValue(I.getArgOperand(2)); + SDValue VLen = getValue(I.getArgOperand(3)); + EVT VT = Src0.getValueType(); + unsigned Alignment = I.getParamAlignment(1); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + SDValue Base; + SDValue Index; + ISD::MemIndexType IndexType; + SDValue Scale; + const Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, IndexType, Scale, this); + + const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr; + MachineMemOperand *MMO = DAG.getMachineFunction().
+ getMachineMemOperand(MachinePointerInfo(MemOpBasePtr), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + if (!UniformBase) { + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Index = getValue(Ptr); + IndexType = ISD::SIGNED_SCALED; + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } + SDValue Ops[] = { getRoot(), Src0, Base, Index, Scale, Mask, VLen }; + SDValue Scatter = DAG.getScatterVP(DAG.getVTList(MVT::Other), VT, sdl, + Ops, MMO, IndexType); + DAG.setRoot(Scatter); + setValue(&I, Scatter); +} + +void SelectionDAGBuilder::visitLoadVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + auto getVPLoadOps = [&](Value* &Ptr, Value* &Mask, Value* &VLen, + unsigned& Alignment) { + // llvm.vp.load.*(Ptr, Mask, VLen); alignment comes from the pointer argument's parameter alignment + Ptr = I.getArgOperand(0); + Alignment = I.getParamAlignment(0); + Mask = I.getArgOperand(1); + VLen = I.getArgOperand(2); + }; + + Value *PtrOperand, *MaskOperand, *VLenOperand; + unsigned Alignment; + getVPLoadOps(PtrOperand, MaskOperand, VLenOperand, Alignment); + + SDValue Ptr = getValue(PtrOperand); + SDValue VLen = getValue(VLenOperand); + SDValue Mask = getValue(MaskOperand); + + // infer the return type; check the element count before indexing into the list + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + SmallVector ValValueVTs; + ComputeValueVTs(TLI, DAG.getDataLayout(), I.getType(), ValValueVTs); + assert((ValValueVTs.size() == 1) && "splitting not implemented"); + EVT VT = ValValueVTs[0]; + + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + // Do not serialize VP loads of constant memory with anything. + bool AddToChain = + !AA || !AA->pointsToConstantMemory(MemoryLocation( + PtrOperand, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo)); + SDValue InChain = AddToChain ?
DAG.getRoot() : DAG.getEntryNode(); + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(PtrOperand), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + SDValue Load = DAG.getLoadVP(VT, sdl, InChain, Ptr, Mask, VLen, VT, MMO, + ISD::NON_EXTLOAD); + if (AddToChain) + PendingLoads.push_back(Load.getValue(1)); + setValue(&I, Load); +} + void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) { SDLoc dl = getCurSDLoc(); AtomicOrdering SuccessOrdering = I.getSuccessOrdering(); @@ -6296,6 +6492,12 @@ #include "llvm/IR/ConstrainedOps.def" visitConstrainedFPIntrinsic(cast(I)); return; + +#define REGISTER_VP_INTRINSIC(VPID,MASKPOS,VLENPOS) case Intrinsic::VPID: +#include "llvm/IR/VPIntrinsics.def" + visitVectorPredicationIntrinsic(cast(I)); + return; + case Intrinsic::fmuladd: { EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); if (TM.Options.AllowFPOpFusion != FPOpFusion::Strict && @@ -7137,6 +7339,112 @@ setValue(&FPI, FPResult); } +void SelectionDAGBuilder::visitCmpVP(const VPIntrinsic &I) { + ISD::CondCode Condition; + CmpInst::Predicate predicate = I.getCmpPredicate(); + bool IsFP = I.getOperand(0)->getType()->isFPOrFPVectorTy(); + if (IsFP) { + Condition = getFCmpCondCode(predicate); + auto *FPMO = dyn_cast(&I); + if ((FPMO && FPMO->hasNoNaNs()) || TM.Options.NoNaNsFPMath) + Condition = getFCmpCodeWithoutNaN(Condition); + + } else { + Condition = getICmpCondCode(predicate); + } + + SDValue Op1 = getValue(I.getOperand(0)); + SDValue Op2 = getValue(I.getOperand(1)); + // #2 is the condition code + SDValue MaskOp = getValue(I.getOperand(3)); + SDValue LenOp = getValue(I.getOperand(4)); + + EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(), + I.getType()); + setValue(&I, DAG.getVPSetCC(getCurSDLoc(), DestVT, Op1, Op2, Condition, MaskOp, LenOp)); +} + +void SelectionDAGBuilder::visitVectorPredicationIntrinsic( + const VPIntrinsic &VPIntrin) { + 
SDLoc sdl = getCurSDLoc(); + unsigned Opcode; + switch (VPIntrin.getIntrinsicID()) { + default: + llvm_unreachable("Unforeseen intrinsic"); // Can't reach here. + + case Intrinsic::vp_load: + visitLoadVP(VPIntrin); + return; + case Intrinsic::vp_store: + visitStoreVP(VPIntrin); + return; + case Intrinsic::vp_gather: + visitGatherVP(VPIntrin); + return; + case Intrinsic::vp_scatter: + visitScatterVP(VPIntrin); + return; + + case Intrinsic::vp_fcmp: + case Intrinsic::vp_icmp: + visitCmpVP(VPIntrin); + return; + + // Generic mappings +#define HANDLE_VP_TO_SDNODE(VPID, NODEID) \ + case Intrinsic::VPID: Opcode = ISD::NODEID; break; +#include "llvm/IR/VPIntrinsics.def" + } + + // TODO memory evl: SDValue Chain = getRoot(); + + SmallVector ValueVTs; + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + ComputeValueVTs(TLI, DAG.getDataLayout(), VPIntrin.getType(), ValueVTs); + SDVTList VTs = DAG.getVTList(ValueVTs); + + // ValueVTs.push_back(MVT::Other); // Out chain + + // Request Operands + SmallVector OpValues; + auto ExceptPosOpt = VPIntrinsic::GetExceptionBehaviorParamPos(VPIntrin.getIntrinsicID()); + auto RoundingModePosOpt = VPIntrinsic::GetRoundingModeParamPos(VPIntrin.getIntrinsicID()); + for (int i = 0; i < (int) VPIntrin.getNumArgOperands(); ++i) { + if (ExceptPosOpt && (i == ExceptPosOpt.getValue())) continue; + if (RoundingModePosOpt && (i == RoundingModePosOpt.getValue())) continue; + OpValues.push_back(getValue(VPIntrin.getArgOperand(i))); + } + SDValue Result = DAG.getNode(Opcode, sdl, VTs, OpValues); + + + SDNodeFlags NodeFlags; + + // set exception flags where appropriate + NodeFlags.setNoFPExcept(!VPIntrin.isConstrainedOp()); + + // copy FMF where available + auto * FPIntrin = dyn_cast(&VPIntrin); + if (FPIntrin) NodeFlags.copyFMF(*FPIntrin); + + if (VPIntrin.isReductionOp()) { + NodeFlags.setVectorReduction(true); + } + + // Attach chain + SDValue VPResult; + if (Result.getNode()->getNumValues() == 2) { + SDValue OutChain = Result.getValue(1); + 
DAG.setRoot(OutChain); + VPResult = Result.getValue(0); + } else { + VPResult = Result; + } + + // attach flags and return + if (NodeFlags.isDefined()) VPResult.getNode()->setFlags(NodeFlags); + setValue(&VPIntrin, VPResult); +} + std::pair SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI, const BasicBlock *EHPadBB) { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -453,6 +453,11 @@ case ISD::VECREDUCE_UMIN: return "vecreduce_umin"; case ISD::VECREDUCE_FMAX: return "vecreduce_fmax"; case ISD::VECREDUCE_FMIN: return "vecreduce_fmin"; + + // Vector Predication +#define REGISTER_VP_SDNODE(NODEID,NAME,MASKPOS,VLENPOS) \ + case ISD::NODEID: return NAME; +#include "llvm/IR/VPIntrinsics.def" } } diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp --- a/llvm/lib/CodeGen/TargetPassConfig.cpp +++ b/llvm/lib/CodeGen/TargetPassConfig.cpp @@ -673,6 +673,11 @@ // Instrument function entry and exit, e.g. with calls to mcount(). addPass(createPostInlineEntryExitInstrumenterPass()); + // Expand vector predication intrinsics into standard IR instructions. + // This pass has to run before ScalarizeMaskedMemIntrin and ExpandReduction + // passes since it emits those kinds of intrinsics. + addPass(createExpandVectorPredicationPass()); + // Add scalarization of target's unsupported masked memory intrinsics pass. // the unsupported intrinsic will be replaced with a chain of basic blocks, // that stores/loads element one-by-one if the appropriate mask bit is set. 
diff --git a/llvm/lib/IR/Attributes.cpp b/llvm/lib/IR/Attributes.cpp --- a/llvm/lib/IR/Attributes.cpp +++ b/llvm/lib/IR/Attributes.cpp @@ -330,6 +330,8 @@ return "builtin"; if (hasAttribute(Attribute::Convergent)) return "convergent"; + if (hasAttribute(Attribute::VectorLength)) + return "vlen"; if (hasAttribute(Attribute::SwiftError)) return "swifterror"; if (hasAttribute(Attribute::SwiftSelf)) @@ -346,6 +348,10 @@ return "inreg"; if (hasAttribute(Attribute::JumpTable)) return "jumptable"; + if (hasAttribute(Attribute::Mask)) + return "mask"; + if (hasAttribute(Attribute::Passthru)) + return "passthru"; if (hasAttribute(Attribute::MinSize)) return "minsize"; if (hasAttribute(Attribute::Naked)) diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt --- a/llvm/lib/IR/CMakeLists.txt +++ b/llvm/lib/IR/CMakeLists.txt @@ -45,6 +45,7 @@ PassManager.cpp PassRegistry.cpp PassTimingInfo.cpp + PredicatedInst.cpp SafepointIRVerifier.cpp ProfileSummary.cpp Statepoint.cpp @@ -52,10 +53,10 @@ TypeFinder.cpp Use.cpp User.cpp + VPBuilder.cpp Value.cpp ValueSymbolTable.cpp Verifier.cpp - ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/IR diff --git a/llvm/lib/IR/FPEnv.cpp b/llvm/lib/IR/FPEnv.cpp --- a/llvm/lib/IR/FPEnv.cpp +++ b/llvm/lib/IR/FPEnv.cpp @@ -12,8 +12,9 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/StringSwitch.h" #include "llvm/IR/FPEnv.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/IR/Metadata.h" namespace llvm { @@ -75,4 +76,22 @@ return ExceptStr; } +Value *GetConstrainedFPExcept(LLVMContext &Context, + fp::ExceptionBehavior UseExcept) { + Optional ExceptStr = ExceptionBehaviorToStr(UseExcept); + assert(ExceptStr.hasValue() && "Garbage strict exception behavior!"); + auto *ExceptMDS = MDString::get(Context, ExceptStr.getValue()); + + return MetadataAsValue::get(Context, ExceptMDS); } + +Value *GetConstrainedFPRounding(LLVMContext &Context, + fp::RoundingMode UseRounding) { + 
Optional RoundingStr = RoundingModeToStr(UseRounding); + assert(RoundingStr.hasValue() && "Garbage strict rounding mode!"); + auto *RoundingMDS = MDString::get(Context, RoundingStr.getValue()); + + return MetadataAsValue::get(Context, RoundingMDS); +} + +} // namespace llvm diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp --- a/llvm/lib/IR/IRBuilder.cpp +++ b/llvm/lib/IR/IRBuilder.cpp @@ -521,6 +521,60 @@ return createCallHelper(TheFn, Ops, this, Name); } +/// Create a call to a vector-predicated intrinsic (VP). +/// \p OC - The LLVM IR Opcode of the operation +/// \p Params - Intrinsic operand list +/// \p FMFSource - Copy source for Fast Math Flags (may be null) +/// \p Name - name of the result variable +Instruction *IRBuilderBase::CreateVectorPredicatedInst(unsigned OC, + ArrayRef Params, + Instruction *FMFSource, + const Twine &Name) { + + Module *M = BB->getParent()->getParent(); + + Intrinsic::ID VPID = VPIntrinsic::GetForOpcode(OC); + auto VPFunc = VPIntrinsic::GetDeclarationForParams(M, VPID, Params); + auto *VPCall = createCallHelper(VPFunc, Params, this, Name); + + // transfer fast math flags + if (FMFSource && isa(FMFSource)) { + VPCall->copyFastMathFlags(FMFSource); + } + + return VPCall; +} + +/// Create a call to a vector-predicated comparison intrinsic (VP).
+/// \p Pred - comparison predicate +/// \p FirstParam - First vector operand +/// \p SndParam - Second vector operand +/// \p MaskParam - Mask operand +/// \p VectorLengthParam - Vector length operand +/// \p Name - name of the result variable +Instruction *IRBuilderBase::CreateVectorPredicatedCmp( + CmpInst::Predicate Pred, Value *FirstParam, Value *SndParam, + Value *MaskParam, Value *VectorLengthParam, const Twine &Name) { + + Module *M = BB->getParent()->getParent(); + + // encode the comparison predicate as an i8 constant operand + uint8_t RawPred = static_cast(Pred); + auto Int8Ty = Type::getInt8Ty(getContext()); + auto PredParam = ConstantInt::get(Int8Ty, RawPred, false); + + Intrinsic::ID VPID = FirstParam->getType()->isIntOrIntVectorTy() + ? Intrinsic::vp_icmp + : Intrinsic::vp_fcmp; + + auto VPFunc = VPIntrinsic::GetDeclarationForParams( + M, VPID, {FirstParam, SndParam, PredParam, MaskParam, VectorLengthParam}); + + return createCallHelper( + VPFunc, {FirstParam, SndParam, PredParam, MaskParam, VectorLengthParam}, + this, Name); +} + /// Create a call to a Masked Gather intrinsic.
/// \p Ptrs - vector of pointers for loading /// \p Align - alignment for one element diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -21,13 +21,15 @@ //===----------------------------------------------------------------------===// #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Operator.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Operator.h" + #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -114,6 +116,7 @@ Optional ConstrainedFPIntrinsic::getExceptionBehavior() const { unsigned NumOperands = getNumArgOperands(); + assert(NumOperands >= 1 && "underflow"); Metadata *MD = cast(getArgOperand(NumOperands - 1))->getMetadata(); if (!MD || !isa(MD)) @@ -178,36 +181,642 @@ } } +ElementCount VPIntrinsic::getStaticVectorLength() const { + auto GetVectorLengthOfType = [](const Type *T) -> ElementCount { + auto VT = cast(T); + auto ElemCount = VT->getElementCount(); + return ElemCount; + }; + + auto VPMask = getMaskParam(); + if (VPMask) { + return GetVectorLengthOfType(VPMask->getType()); + } + + // only compose does not have a mask param + assert(getIntrinsicID() == Intrinsic::vp_compose); + return GetVectorLengthOfType(getType()); +} + +void VPIntrinsic::setMaskParam(Value *NewMask) { + auto MaskPos = GetMaskParamPos(getIntrinsicID()); + assert(MaskPos.hasValue()); + this->setOperand(MaskPos.getValue(), NewMask); +} + +void VPIntrinsic::setVectorLengthParam(Value *NewVL) { + auto VLPos = GetVectorLengthParamPos(getIntrinsicID()); + assert(VLPos.hasValue()); + this->setOperand(VLPos.getValue(), NewVL); +} + +Value *VPIntrinsic::getMaskParam() const { + auto maskPos = GetMaskParamPos(getIntrinsicID()); + if (maskPos) + return 
getArgOperand(maskPos.getValue()); + return nullptr; +} + +Value *VPIntrinsic::getVectorLengthParam() const { + auto vlenPos = GetVectorLengthParamPos(getIntrinsicID()); + if (vlenPos) + return getArgOperand(vlenPos.getValue()); + return nullptr; +} + +Optional VPIntrinsic::GetMaskParamPos(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: + return None; + +#define REGISTER_VP_INTRINSIC(VPID, MASKPOS, VLENPOS) \ + case Intrinsic::VPID: \ + return MASKPOS; +#include "llvm/IR/VPIntrinsics.def" + } +} + +Optional VPIntrinsic::GetVectorLengthParamPos(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: + return None; + +#define REGISTER_VP_INTRINSIC(VPID, MASKPOS, VLENPOS) \ + case Intrinsic::VPID: \ + return VLENPOS; +#include "llvm/IR/VPIntrinsics.def" + } +} + +bool VPIntrinsic::IsVPIntrinsic(Intrinsic::ID ID) { + switch (ID) { + default: + return false; + +#define REGISTER_VP_INTRINSIC(VPID, MASKPOS, VLENPOS) \ + case Intrinsic::VPID: \ + break; +#include "llvm/IR/VPIntrinsics.def" + } + return true; +} + +Intrinsic::ID VPIntrinsic::GetConstrainedIntrinsicForVP(Intrinsic::ID VPID) { + switch (VPID) { + default: + return Intrinsic::not_intrinsic; + +#define HANDLE_VP_TO_CONSTRAINED_INTRIN(VPID, CFPID) \ + case Intrinsic::VPID: \ + return Intrinsic::CFPID; +#include "llvm/IR/VPIntrinsics.def" + } +} + +Intrinsic::ID VPIntrinsic::GetFunctionalIntrinsicForVP(Intrinsic::ID VPID) { + switch (VPID) { + default: + return Intrinsic::not_intrinsic; + +#define HANDLE_VP_TO_INTRIN(VPID, IID) \ + case Intrinsic::VPID: \ + return Intrinsic::IID; +#include "llvm/IR/VPIntrinsics.def" + } +} + +// Equivalent non-predicated opcode +unsigned VPIntrinsic::GetFunctionalOpcodeForVP(Intrinsic::ID ID) { + switch (ID) { + default: + return Instruction::Call; + +#define HANDLE_VP_TO_OC(VPID, OC) \ + case Intrinsic::VPID: \ + return Instruction::OC; +#include "llvm/IR/VPIntrinsics.def" + } +} + +Intrinsic::ID VPIntrinsic::GetForOpcode(unsigned OC) { + switch (OC) { 
+ default: + return Intrinsic::not_intrinsic; + +#define HANDLE_VP_TO_OC(VPID, OC) \ + case Instruction::OC: \ + return Intrinsic::VPID; +#include "llvm/IR/VPIntrinsics.def" + } +} + +bool VPIntrinsic::canIgnoreVectorLengthParam() const { + using namespace PatternMatch; + + ElementCount EC = getStaticVectorLength(); + + // No vlen param - no lanes masked-off by it. + auto *VLParam = getVectorLengthParam(); + if (!VLParam) + return true; + + // Note that the VP intrinsic causes undefined behavior if the Explicit Vector + // Length parameter is strictly greater-than the number of vector elements of + // the operation. This function returns true when this is detected statically + // in the IR. + + // Check whether "W == vscale * EC.Min" + if (EC.Scalable) { + // Undig the DL + auto ParMod = this->getModule(); + if (!ParMod) + return false; + const auto &DL = ParMod->getDataLayout(); + + // Compare vscale patterns + uint64_t ParamFactor; + if (EC.Min > 1 && + match(VLParam, m_c_BinOp(m_ConstantInt(ParamFactor), m_VScale(DL)))) { + return ParamFactor >= EC.Min; + } + if (match(VLParam, m_VScale(DL))) { + return ParamFactor; + } + return false; + } + + // standard SIMD operation + auto VLConst = dyn_cast(VLParam); + if (!VLConst) + return false; + + uint64_t VLNum = VLConst->getZExtValue(); + if (VLNum >= EC.Min) + return true; + + return false; +} + +CmpInst::Predicate VPIntrinsic::getCmpPredicate() const { + return static_cast( + cast(getArgOperand(2))->getZExtValue()); +} + +Optional VPIntrinsic::getRoundingMode() const { + auto RmParamPos = GetRoundingModeParamPos(getIntrinsicID()); + if (!RmParamPos) + return None; + + Metadata *MD = dyn_cast(getArgOperand(RmParamPos.getValue())) + ->getMetadata(); + if (!MD || !isa(MD)) + return None; + StringRef RoundingArg = cast(MD)->getString(); + return StrToRoundingMode(RoundingArg); +} + +Optional VPIntrinsic::getExceptionBehavior() const { + auto EbParamPos = GetExceptionBehaviorParamPos(getIntrinsicID()); + if 
(!EbParamPos) + return None; + + Metadata *MD = dyn_cast(getArgOperand(EbParamPos.getValue())) + ->getMetadata(); + if (!MD || !isa(MD)) + return None; + StringRef ExceptionArg = cast(MD)->getString(); + return StrToExceptionBehavior(ExceptionArg); +} + +/// \return The vector to reduce if this is a reduction operation. +Value *VPIntrinsic::getReductionVectorParam() const { + auto PosOpt = GetReductionVectorParamPos(getIntrinsicID()); + if (!PosOpt.hasValue()) + return nullptr; + return getArgOperand(PosOpt.getValue()); +} + +Optional VPIntrinsic::GetReductionVectorParamPos(Intrinsic::ID VPID) { + switch (VPID) { + default: + return None; + +#define HANDLE_VP_REDUCTION(VPID, ACCUPOS, VECTORPOS) \ + case Intrinsic::VPID: \ + return VECTORPOS; +#include "llvm/IR/VPIntrinsics.def" + } +} + +/// \return The accumulator initial value if this is a reduction operation. +Value *VPIntrinsic::getReductionAccuParam() const { + auto PosOpt = GetReductionAccuParamPos(getIntrinsicID()); + if (!PosOpt.hasValue()) + return nullptr; + return getArgOperand(PosOpt.getValue()); +} + +Optional VPIntrinsic::GetReductionAccuParamPos(Intrinsic::ID VPID) { + switch (VPID) { + default: + return None; + +#define HANDLE_VP_REDUCTION(VPID, ACCUPOS, VECTORPOS) \ + case Intrinsic::VPID: \ + return ACCUPOS; +#include "llvm/IR/VPIntrinsics.def" + } +} + +/// \return the alignment of the pointer used by this load/store/gather or +/// scatter. +MaybeAlign VPIntrinsic::getPointerAlignment() const { + Optional PtrParamOpt = GetMemoryPointerParamPos(getIntrinsicID()); + assert(PtrParamOpt.hasValue() && "no pointer argument!"); + unsigned AlignVal = this->getParamAlignment(PtrParamOpt.getValue()); + if (AlignVal) { + return MaybeAlign(AlignVal); + } + return None; +} + +/// \return The pointer operand of this load,store, gather or scatter. 
+Value *VPIntrinsic::getMemoryPointerParam() const { + auto PtrParamOpt = GetMemoryPointerParamPos(getIntrinsicID()); + if (!PtrParamOpt.hasValue()) + return nullptr; + return getArgOperand(PtrParamOpt.getValue()); +} + +Optional VPIntrinsic::GetMemoryPointerParamPos(Intrinsic::ID VPID) { + switch (VPID) { + default: + return None; + +#define HANDLE_VP_IS_MEMOP(VPID, POINTERPOS, DATAPOS) \ + case Intrinsic::VPID: \ + return POINTERPOS; +#include "llvm/IR/VPIntrinsics.def" + } +} + +/// \return The data (payload) operand of this store or scatter. +Value *VPIntrinsic::getMemoryDataParam() const { + auto DataParamOpt = GetMemoryDataParamPos(getIntrinsicID()); + if (!DataParamOpt.hasValue()) + return nullptr; + return getArgOperand(DataParamOpt.getValue()); +} + +Optional VPIntrinsic::GetMemoryDataParamPos(Intrinsic::ID VPID) { + switch (VPID) { + default: + return None; + +#define HANDLE_VP_IS_MEMOP(VPID, POINTERPOS, DATAPOS) \ + case Intrinsic::VPID: \ + return DATAPOS; +#include "llvm/IR/VPIntrinsics.def" + } +} + +Function *VPIntrinsic::GetDeclarationForParams(Module *M, Intrinsic::ID VPID, + ArrayRef Params, + Type *VecRetTy) { + assert(VPID != Intrinsic::not_intrinsic && "todo dispatch to default insts"); + + bool IsArithOp = VPIntrinsic::IsBinaryVPOp(VPID) || + VPIntrinsic::IsUnaryVPOp(VPID) || + VPIntrinsic::IsTernaryVPOp(VPID); + bool IsCmpOp = (VPID == Intrinsic::vp_icmp) || (VPID == Intrinsic::vp_fcmp); + bool IsReduceOp = VPIntrinsic::IsVPReduction(VPID); + bool IsShuffleOp = + (VPID == Intrinsic::vp_compress) || (VPID == Intrinsic::vp_expand) || + (VPID == Intrinsic::vp_vshift) || (VPID == Intrinsic::vp_select) || + (VPID == Intrinsic::vp_compose); + bool IsMemoryOp = + (VPID == Intrinsic::vp_store) || (VPID == Intrinsic::vp_load) || + (VPID == Intrinsic::vp_store) || (VPID == Intrinsic::vp_load); + bool IsCastOp = + (VPID == Intrinsic::vp_fptosi) || (VPID == Intrinsic::vp_fptoui) || + (VPID == Intrinsic::vp_sitofp) || (VPID == Intrinsic::vp_uitofp) || + 
(VPID == Intrinsic::vp_fpext) || (VPID == Intrinsic::vp_fptrunc); + + Type *VecTy = nullptr; + Type *VecPtrTy = nullptr; + + if (IsArithOp || IsCmpOp || IsCastOp) { + Value &FirstOp = *Params[0]; + + // Fetch the VP intrinsic + VecTy = cast(FirstOp.getType()); + + } else if (IsReduceOp) { + auto VectorPosOpt = GetReductionVectorParamPos(VPID); + Value *VectorParam = Params[VectorPosOpt.getValue()]; + + VecTy = VectorParam->getType(); + + } else if (IsMemoryOp) { + auto DataPosOpt = VPIntrinsic::GetMemoryDataParamPos(VPID); + auto PtrPosOpt = VPIntrinsic::GetMemoryPointerParamPos(VPID); + VecPtrTy = Params[PtrPosOpt.getValue()]->getType(); + + if (DataPosOpt.hasValue()) { + // store-kind operation + VecTy = Params[DataPosOpt.getValue()]->getType(); + } else { + // load-kind operation + VecTy = VecPtrTy->getPointerElementType(); + } + + } else if (IsShuffleOp) { + VecTy = (VPID == Intrinsic::vp_select) ? Params[1]->getType() + : Params[0]->getType(); + } + + auto TypeTokens = VPIntrinsic::GetTypeTokens(VPID); + auto *VPFunc = Intrinsic::getDeclaration( + M, VPID, + VPIntrinsic::EncodeTypeTokens(TypeTokens, VecRetTy, VecPtrTy, *VecTy)); + assert(VPFunc && "not a VP intrinsic"); + + return VPFunc; +} + +VPIntrinsic::TypeTokenVec VPIntrinsic::GetTypeTokens(Intrinsic::ID ID) { + switch (ID) { + default: + llvm_unreachable("not implemented!"); + + case Intrinsic::vp_cos: + case Intrinsic::vp_sin: + case Intrinsic::vp_exp: + case Intrinsic::vp_exp2: + + case Intrinsic::vp_log: + case Intrinsic::vp_log2: + case Intrinsic::vp_log10: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_ceil: + case Intrinsic::vp_floor: + case Intrinsic::vp_round: + case Intrinsic::vp_trunc: + case Intrinsic::vp_rint: + case Intrinsic::vp_nearbyint: + + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case 
Intrinsic::vp_sdiv: + case Intrinsic::vp_udiv: + case Intrinsic::vp_srem: + case Intrinsic::vp_urem: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + case Intrinsic::vp_vshift: + return TypeTokenVec{VPTypeToken::Vector}; + + case Intrinsic::vp_select: + return TypeTokenVec{VPTypeToken::Returned}; + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + return TypeTokenVec{VPTypeToken::Vector}; + + case Intrinsic::vp_gather: + case Intrinsic::vp_load: + return TypeTokenVec{VPTypeToken::Returned, VPTypeToken::Pointer}; + + case Intrinsic::vp_scatter: + case Intrinsic::vp_store: + return TypeTokenVec{VPTypeToken::Pointer, VPTypeToken::Vector}; + + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_sitofp: + case Intrinsic::vp_uitofp: + return TypeTokenVec{VPTypeToken::Returned, VPTypeToken::Vector}; + + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + return TypeTokenVec{VPTypeToken::Vector}; + } +} + +bool VPIntrinsic::isReductionOp() const { + return IsVPReduction(getIntrinsicID()); +} + +bool VPIntrinsic::IsVPReduction(Intrinsic::ID ID) { + switch (ID) { + default: + return false; + +#define HANDLE_VP_REDUCTION(VPID, ACCUPOS, VECTORPOS) \ + case Intrinsic::VPID: \ + break; +#include "llvm/IR/VPIntrinsics.def" + } + + return true; +} + +bool VPIntrinsic::isConstrainedOp() const { + return 
(getRoundingMode() != None && + getRoundingMode() != fp::RoundingMode::rmToNearest) || + (getExceptionBehavior() != None && + getExceptionBehavior() != fp::ExceptionBehavior::ebIgnore); +} + +bool VPIntrinsic::isUnaryOp() const { return IsUnaryVPOp(getIntrinsicID()); } + +bool VPIntrinsic::IsUnaryVPOp(Intrinsic::ID VPID) { + switch (VPID) { + default: + return false; + +#define HANDLE_VP_UNARYOP(VPID) \ + case Intrinsic::VPID: \ + return true; +#include "llvm/IR/VPIntrinsics.def" + } +} + +bool VPIntrinsic::isBinaryOp() const { return IsBinaryVPOp(getIntrinsicID()); } + +bool VPIntrinsic::IsBinaryVPOp(Intrinsic::ID VPID) { + switch (VPID) { + default: + return false; + +#define HANDLE_VP_IS_BINARY(VPID) \ + case Intrinsic::VPID: \ + return true; +#include "llvm/IR/VPIntrinsics.def" + } +} + +bool VPIntrinsic::isTernaryOp() const { + return IsTernaryVPOp(getIntrinsicID()); +} + +bool VPIntrinsic::IsTernaryVPOp(Intrinsic::ID VPID) { + switch (VPID) { + default: + return false; + +#define HANDLE_VP_IS_TERNARY(VPID) \ + case Intrinsic::VPID: \ + return true; +#include "llvm/IR/VPIntrinsics.def" + } +} + +bool VPIntrinsic::isCompareOp() const { + return IsCompareVPOp(getIntrinsicID()); +} + +bool VPIntrinsic::IsCompareVPOp(Intrinsic::ID VPID) { + switch (VPID) { + default: + return false; + +#define HANDLE_VP_IS_XCMP(VPID) \ + case Intrinsic::VPID: \ + return true; +#include "llvm/IR/VPIntrinsics.def" + } +} + +Optional +VPIntrinsic::GetExceptionBehaviorParamPos(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: + return None; + +#define HANDLE_VP_FPCONSTRAINT(VPID, ROUNDPOS, EXCEPTPOS) \ + case Intrinsic::VPID: \ + return EXCEPTPOS; +#include "llvm/IR/VPIntrinsics.def" + } +} + +Optional VPIntrinsic::GetRoundingModeParamPos(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: + return None; + +#define HANDLE_VP_FPCONSTRAINT(VPID, ROUNDPOS, EXCEPTPOS) \ + case Intrinsic::VPID: \ + return ROUNDPOS; +#include "llvm/IR/VPIntrinsics.def" + } +} + 
+Intrinsic::ID VPIntrinsic::GetForIntrinsic(Intrinsic::ID IntrinsicID) { + Optional ConstrainedID; + switch (IntrinsicID) { + default: + return Intrinsic::not_intrinsic; + +#define HANDLE_VP_TO_CONSTRAINED_INTRIN(VPID, CFPID) return Intrinsic::VPID; +#define HANDLE_VP_TO_INTRIN(VPID, IID) return Intrinsic::VPID; +#include "llvm/IR/VPIntrinsics.def" + } +} + +VPIntrinsic::ShortTypeVec +VPIntrinsic::EncodeTypeTokens(VPIntrinsic::TypeTokenVec TTVec, Type *VecRetTy, + Type *VecPtrTy, Type &VectorTy) { + ShortTypeVec STV; + + for (auto Token : TTVec) { + switch (Token) { + default: + llvm_unreachable("unsupported token"); // unsupported VPTypeToken + + case VPIntrinsic::VPTypeToken::Vector: + STV.push_back(&VectorTy); + break; + case VPIntrinsic::VPTypeToken::Pointer: + STV.push_back(VecPtrTy); + break; + case VPIntrinsic::VPTypeToken::Returned: + assert(VecRetTy); + STV.push_back(VecRetTy); + break; + case VPIntrinsic::VPTypeToken::Mask: + auto NumElems = VectorTy.getVectorNumElements(); + auto MaskTy = + VectorType::get(Type::getInt1Ty(VectorTy.getContext()), NumElems); + STV.push_back(MaskTy); + break; + } + } + + return STV; +} + Instruction::BinaryOps BinaryOpIntrinsic::getBinaryOp() const { switch (getIntrinsicID()) { - case Intrinsic::uadd_with_overflow: - case Intrinsic::sadd_with_overflow: - case Intrinsic::uadd_sat: - case Intrinsic::sadd_sat: - return Instruction::Add; - case Intrinsic::usub_with_overflow: - case Intrinsic::ssub_with_overflow: - case Intrinsic::usub_sat: - case Intrinsic::ssub_sat: - return Instruction::Sub; - case Intrinsic::umul_with_overflow: - case Intrinsic::smul_with_overflow: - return Instruction::Mul; - default: - llvm_unreachable("Invalid intrinsic"); + case Intrinsic::uadd_with_overflow: + case Intrinsic::sadd_with_overflow: + case Intrinsic::uadd_sat: + case Intrinsic::sadd_sat: + return Instruction::Add; + case Intrinsic::usub_with_overflow: + case Intrinsic::ssub_with_overflow: + case Intrinsic::usub_sat: + case 
Intrinsic::ssub_sat: + return Instruction::Sub; + case Intrinsic::umul_with_overflow: + case Intrinsic::smul_with_overflow: + return Instruction::Mul; + default: + llvm_unreachable("Invalid intrinsic"); } } bool BinaryOpIntrinsic::isSigned() const { switch (getIntrinsicID()) { - case Intrinsic::sadd_with_overflow: - case Intrinsic::ssub_with_overflow: - case Intrinsic::smul_with_overflow: - case Intrinsic::sadd_sat: - case Intrinsic::ssub_sat: - return true; - default: - return false; + case Intrinsic::sadd_with_overflow: + case Intrinsic::ssub_with_overflow: + case Intrinsic::smul_with_overflow: + case Intrinsic::sadd_sat: + case Intrinsic::ssub_sat: + return true; + default: + return false; } } diff --git a/llvm/lib/IR/PredicatedInst.cpp b/llvm/lib/IR/PredicatedInst.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/IR/PredicatedInst.cpp @@ -0,0 +1,115 @@ +#include +#include +#include +#include +#include + +namespace { +using namespace llvm; +using ShortValueVec = SmallVector; +} // namespace + +namespace llvm { + +bool PredicatedInstruction::canIgnoreVectorLengthParam() const { + auto VPI = dyn_cast(this); + if (!VPI) + return true; + + return VPI->canIgnoreVectorLengthParam(); +} + +FastMathFlags PredicatedInstruction::getFastMathFlags() const { + return cast(this)->getFastMathFlags(); +} + +void PredicatedOperator::copyIRFlags(const Value *V, bool IncludeWrapFlags) { + auto *I = dyn_cast(this); + if (I) + I->copyIRFlags(V, IncludeWrapFlags); +} + +bool +PredicatedInstruction::isVectorReduction() const { + auto VPI = dyn_cast(this); + if (VPI) { + return VPI->isReductionOp(); + } + auto II = dyn_cast(this); + if (!II) return false; + + switch (II->getIntrinsicID()) { + default: + return false; + + case Intrinsic::experimental_vector_reduce_add: + case Intrinsic::experimental_vector_reduce_mul: + case Intrinsic::experimental_vector_reduce_and: + case Intrinsic::experimental_vector_reduce_or: + case Intrinsic::experimental_vector_reduce_xor: + case 
Intrinsic::experimental_vector_reduce_smin: + case Intrinsic::experimental_vector_reduce_smax: + case Intrinsic::experimental_vector_reduce_umin: + case Intrinsic::experimental_vector_reduce_umax: + case Intrinsic::experimental_vector_reduce_v2_fadd: + case Intrinsic::experimental_vector_reduce_v2_fmul: + case Intrinsic::experimental_vector_reduce_fmin: + case Intrinsic::experimental_vector_reduce_fmax: + return true; + } +} + +Instruction *PredicatedBinaryOperator::Create( + Module *Mod, Value *Mask, Value *VectorLen, Instruction::BinaryOps Opc, + Value *V1, Value *V2, const Twine &Name, BasicBlock *InsertAtEnd, + Instruction *InsertBefore) { + assert(!(InsertAtEnd && InsertBefore)); + auto VPID = VPIntrinsic::GetForOpcode(Opc); + + // Default Code Path + if ((!Mod || (!Mask && !VectorLen)) || VPID == Intrinsic::not_intrinsic) { + if (InsertAtEnd) { + return BinaryOperator::Create(Opc, V1, V2, Name, InsertAtEnd); + } else { + return BinaryOperator::Create(Opc, V1, V2, Name, InsertBefore); + } + } + + assert(Mod && "Need a module to emit VP Intrinsics"); + + // Fetch the VP intrinsic + auto &VecTy = cast(*V1->getType()); + auto TypeTokens = VPIntrinsic::GetTypeTokens(VPID); + auto *VPFunc = Intrinsic::getDeclaration( + Mod, VPID, + VPIntrinsic::EncodeTypeTokens(TypeTokens, &VecTy, nullptr, VecTy)); + + // Encode default environment fp behavior + LLVMContext &Ctx = V1->getContext(); + SmallVector BinOpArgs({V1, V2}); + if (VPIntrinsic::HasRoundingModeParam(VPID)) { + BinOpArgs.push_back( + GetConstrainedFPRounding(Ctx, fp::RoundingMode::rmToNearest)); + } + if (VPIntrinsic::HasExceptionBehaviorParam(VPID)) { + BinOpArgs.push_back( + GetConstrainedFPExcept(Ctx, fp::ExceptionBehavior::ebIgnore)); + } + + BinOpArgs.push_back(Mask); + BinOpArgs.push_back(VectorLen); + + CallInst *CI; + if (InsertAtEnd) { + CI = CallInst::Create(VPFunc, BinOpArgs, Name, InsertAtEnd); + } else { + CI = CallInst::Create(VPFunc, BinOpArgs, Name, InsertBefore); + } + + // the VP inst does 
not touch memory if the exception behavior is + // "fpecept.ignore" + CI->setDoesNotAccessMemory(); + return CI; +} + +} // namespace llvm diff --git a/llvm/lib/IR/VPBuilder.cpp b/llvm/lib/IR/VPBuilder.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/IR/VPBuilder.cpp @@ -0,0 +1,182 @@ +#include +#include +#include +#include +#include +#include + +namespace { +using namespace llvm; +using ShortTypeVec = VPIntrinsic::ShortTypeVec; +using ShortValueVec = SmallVector; +} // namespace + +namespace llvm { + +Module &VPBuilder::getModule() const { + return *Builder.GetInsertBlock()->getParent()->getParent(); +} + +Value &VPBuilder::RequestPred() { + if (Mask) + return *Mask; + + auto *boolTy = Builder.getInt1Ty(); + auto *maskTy = VectorType::get(boolTy, StaticVectorLength); + return *ConstantInt::getAllOnesValue(maskTy); +} + +Value &VPBuilder::RequestEVL() { + if (ExplicitVectorLength) + return *ExplicitVectorLength; + + auto *intTy = Builder.getInt32Ty(); + return *ConstantInt::get(intTy, StaticVectorLength); +} + +Value *VPBuilder::CreateVectorCopy(Instruction &Inst, ValArray VecOpArray) { + auto OC = Inst.getOpcode(); + auto VPID = VPIntrinsic::GetForOpcode(OC); + if (VPID == Intrinsic::not_intrinsic) { + return nullptr; + } + + Optional MaskPosOpt = VPIntrinsic::GetMaskParamPos(VPID); + Optional VLenPosOpt = VPIntrinsic::GetVectorLengthParamPos(VPID); + Optional FPRoundPosOpt = VPIntrinsic::GetRoundingModeParamPos(VPID); + Optional FPExceptPosOpt = + VPIntrinsic::GetExceptionBehaviorParamPos(VPID); + + Optional CmpPredPos = None; + if (isa(Inst)) { + CmpPredPos = 2; + } + + // TODO transfer alignment + + // construct VP vector operands (including pred and evl) + SmallVector VecParams; + for (size_t i = 0; i < Inst.getNumOperands() + 5; ++i) { + if (MaskPosOpt && (i == (size_t)MaskPosOpt.getValue())) { + // First operand of select is mask (singular exception) + if (VPID != Intrinsic::vp_select) + VecParams.push_back(&RequestPred()); + } + if (VLenPosOpt && (i == 
(size_t)VLenPosOpt.getValue())) { + VecParams.push_back(&RequestEVL()); + } + if (FPRoundPosOpt && (i == (size_t)FPRoundPosOpt.getValue())) { + // TODO decode fp env from constrained intrinsics + VecParams.push_back(GetConstrainedFPRounding( + Builder.getContext(), fp::RoundingMode::rmToNearest)); + } + if (FPExceptPosOpt && (i == (size_t)FPExceptPosOpt.getValue())) { + // TODO decode fp env from constrained intrinsics + VecParams.push_back(GetConstrainedFPExcept( + Builder.getContext(), fp::ExceptionBehavior::ebIgnore)); + } + if (CmpPredPos && (i == (size_t)CmpPredPos.getValue())) { + auto &CmpI = cast(Inst); + VecParams.push_back(ConstantInt::get( + Type::getInt8Ty(Builder.getContext()), CmpI.getPredicate())); + } + if (i < VecOpArray.size()) + VecParams.push_back(VecOpArray[i]); + } + + Type *ScaRetTy = Inst.getType(); + Type *VecRetTy = ScaRetTy->isVoidTy() ? ScaRetTy : &getVectorType(*ScaRetTy); + auto &M = *Builder.GetInsertBlock()->getParent()->getParent(); + auto VPDecl = + VPIntrinsic::GetDeclarationForParams(&M, VPID, VecParams, VecRetTy); + + return Builder.CreateCall(VPDecl, VecParams, Inst.getName() + ".vp"); +} + +VectorType &VPBuilder::getVectorType(Type &ElementTy) { + return *VectorType::get(&ElementTy, StaticVectorLength); +} + +Value &VPBuilder::CreateContiguousStore(Value &Val, Value &ElemPointer, + MaybeAlign AlignOpt) { + auto &PointerTy = cast(*ElemPointer.getType()); + auto &VecTy = getVectorType(*PointerTy.getPointerElementType()); + auto *VecPtrTy = VecTy.getPointerTo(PointerTy.getAddressSpace()); + auto *VecPtr = Builder.CreatePointerCast(&ElemPointer, VecPtrTy); + + auto *StoreFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_store, + {&VecTy, VecPtrTy}); + ShortValueVec Args{&Val, VecPtr, &RequestPred(), &RequestEVL()}; + CallInst &StoreCall = *Builder.CreateCall(StoreFunc, Args); + if (AlignOpt.hasValue()) { + unsigned PtrPos = + VPIntrinsic::GetMemoryPointerParamPos(Intrinsic::vp_store).getValue(); + 
StoreCall.addParamAttr( + PtrPos, Attribute::getWithAlignment(getContext(), AlignOpt.getValue())); + } + return StoreCall; +} + +Value &VPBuilder::CreateContiguousLoad(Value &ElemPointer, + MaybeAlign AlignOpt) { + auto &PointerTy = cast(*ElemPointer.getType()); + auto &VecTy = getVectorType(*PointerTy.getPointerElementType()); + auto *VecPtrTy = VecTy.getPointerTo(PointerTy.getAddressSpace()); + auto *VecPtr = Builder.CreatePointerCast(&ElemPointer, VecPtrTy); + + auto *LoadFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_load, + {&VecTy, VecPtrTy}); + ShortValueVec Args{VecPtr, &RequestPred(), &RequestEVL()}; + CallInst &LoadCall = *Builder.CreateCall(LoadFunc, Args); + if (AlignOpt.hasValue()) { + unsigned PtrPos = + VPIntrinsic::GetMemoryPointerParamPos(Intrinsic::vp_load).getValue(); + LoadCall.addParamAttr( + PtrPos, Attribute::getWithAlignment(getContext(), AlignOpt.getValue())); + } + return LoadCall; +} + +Value &VPBuilder::CreateScatter(Value &Val, Value &PointerVec, + MaybeAlign AlignOpt) { + auto *ScatterFunc = + Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_scatter, + {Val.getType(), PointerVec.getType()}); + ShortValueVec Args{&Val, &PointerVec, &RequestPred(), &RequestEVL()}; + CallInst &ScatterCall = *Builder.CreateCall(ScatterFunc, Args); + if (AlignOpt.hasValue()) { + unsigned PtrPos = + VPIntrinsic::GetMemoryPointerParamPos(Intrinsic::vp_scatter).getValue(); + ScatterCall.addParamAttr( + PtrPos, Attribute::getWithAlignment(getContext(), AlignOpt.getValue())); + } + return ScatterCall; +} + +Value &VPBuilder::CreateGather(Value &PointerVec, MaybeAlign AlignOpt) { + auto &PointerVecTy = cast(*PointerVec.getType()); + auto &ElemTy = *cast(*PointerVecTy.getVectorElementType()) + .getPointerElementType(); + auto &VecTy = *VectorType::get(&ElemTy, PointerVecTy.getNumElements()); + auto *GatherFunc = Intrinsic::getDeclaration( + &getModule(), Intrinsic::vp_gather, {&VecTy, &PointerVecTy}); + + ShortValueVec Args{&PointerVec, 
&RequestPred(), &RequestEVL()}; + CallInst &GatherCall = *Builder.CreateCall(GatherFunc, Args); + if (AlignOpt.hasValue()) { + unsigned PtrPos = + VPIntrinsic::GetMemoryPointerParamPos(Intrinsic::vp_gather).getValue(); + GatherCall.addParamAttr( + PtrPos, Attribute::getWithAlignment(getContext(), AlignOpt.getValue())); + } + return GatherCall; +} + +Value *VPBuilder::CreateVectorShift(Value *SrcVal, Value *Amount, Twine Name) { + auto D = VPIntrinsic::GetDeclarationForParams( + &getModule(), Intrinsic::vp_vshift, {SrcVal, Amount}); + return Builder.CreateCall(D, {SrcVal, Amount, &RequestPred(), &RequestEVL()}, + Name); +} + +} // namespace llvm diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -93,6 +93,7 @@ #include "llvm/IR/ModuleSlotTracker.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Statepoint.h" +#include "llvm/IR/FPEnv.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -477,6 +478,7 @@ void visitUserOp2(Instruction &I) { visitUserOp1(I); } void visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call); void visitConstrainedFPIntrinsic(ConstrainedFPIntrinsic &FPI); + void visitVPIntrinsic(VPIntrinsic &FPI); void visitDbgIntrinsic(StringRef Kind, DbgVariableIntrinsic &DII); void visitDbgLabelIntrinsic(StringRef Kind, DbgLabelInst &DLI); void visitAtomicCmpXchgInst(AtomicCmpXchgInst &CXI); @@ -1705,11 +1707,14 @@ if (Attrs.isEmpty()) return; + bool SawMask = false; bool SawNest = false; + bool SawPassthru = false; bool SawReturned = false; bool SawSRet = false; bool SawSwiftSelf = false; bool SawSwiftError = false; + bool SawVectorLength = false; // Verify return value attributes. 
AttributeSet RetAttrs = Attrs.getRetAttributes(); @@ -1778,12 +1783,33 @@ SawSwiftError = true; } + if (ArgAttrs.hasAttribute(Attribute::VectorLength)) { + Assert(!SawVectorLength, "Cannot have multiple 'vlen' parameters!", + V); + SawVectorLength = true; + } + + if (ArgAttrs.hasAttribute(Attribute::Passthru)) { + Assert(!SawPassthru, "Cannot have multiple 'passthru' parameters!", + V); + SawPassthru = true; + } + + if (ArgAttrs.hasAttribute(Attribute::Mask)) { + Assert(!SawMask, "Cannot have multiple 'mask' parameters!", + V); + SawMask = true; + } + if (ArgAttrs.hasAttribute(Attribute::InAlloca)) { Assert(i == FT->getNumParams() - 1, "inalloca isn't on the last parameter!", V); } } + Assert(!SawPassthru || SawMask, + "Cannot have 'passthru' parameter without 'mask' parameter!", V); + if (!Attrs.hasAttributes(AttributeList::FunctionIndex)) return; @@ -4366,6 +4392,13 @@ #include "llvm/IR/ConstrainedOps.def" visitConstrainedFPIntrinsic(cast(Call)); break; + +#define REGISTER_VP_INTRINSIC(VPID,MASKPOS,VLENPOS) \ + case Intrinsic::VPID: +#include "llvm/IR/VPIntrinsics.def" + visitVPIntrinsic(cast(Call)); + break; + case Intrinsic::dbg_declare: // llvm.dbg.declare Assert(isa(Call.getArgOperand(0)), "invalid llvm.dbg.declare intrinsic call 1", Call); @@ -4811,6 +4844,18 @@ return nullptr; } +void Verifier::visitVPIntrinsic(VPIntrinsic &VPI) { + Assert(!VPI.isConstrainedOp(), + "VP intrinsics only support the default fp environment for now " + "(round.tonearest; fpexcept.ignore)."); + if (VPI.isConstrainedOp()) { + Assert(VPI.getExceptionBehavior() != None, + "invalid exception behavior argument", &VPI); + Assert(VPI.getRoundingMode() != None, "invalid rounding mode argument", + &VPI); + } +} + void Verifier::visitConstrainedFPIntrinsic(ConstrainedFPIntrinsic &FPI) { unsigned NumOperands; bool HasRoundingMD; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp --- 
a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -24,6 +24,9 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" +#include "llvm/IR/VPBuilder.h" +#include "llvm/IR/MatcherCast.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/AlignOf.h" @@ -2054,29 +2057,33 @@ /// This eliminates floating-point negation in either 'fneg(X)' or /// 'fsub(-0.0, X)' form by combining into a constant operand. +template static Instruction *foldFNegIntoConstant(Instruction &I) { Value *X; Constant *C; + MatchContextType MC(cast(&I)); + MatchContextBuilder MCBuilder(MC); + // Fold negation into constant operand. This is limited with one-use because // fneg is assumed better for analysis and cheaper in codegen than fmul/fdiv. - // -(X * C) --> X * (-C) // FIXME: It's arguable whether these should be m_OneUse or not. The current // belief is that the FNeg allows for better reassociation opportunities. 
- if (match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) - return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + // -(X * C) --> X * (-C) + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); // -(X / C) --> X / (-C) - if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) - return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); // -(C / X) --> (-C) / X - if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) - return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) + return MCBuilder.CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); // With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]: // -(X + C) --> -X + -C --> -C - X if (I.hasNoSignedZeros() && - match(&I, m_FNeg(m_OneUse(m_FAdd(m_Value(X), m_Constant(C)))))) - return BinaryOperator::CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I); + MC.try_match(&I, m_FNeg(m_OneUse(m_FAdd(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I); return nullptr; } @@ -2104,7 +2111,7 @@ SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldFNegIntoConstant(I)) + if (Instruction *X = foldFNegIntoConstant(I)) return X; Value *X, *Y; @@ -2120,6 +2127,17 @@ return nullptr; } +Instruction *InstCombiner::visitPredicatedFSub(PredicatedBinaryOperator& I) { + auto * Inst = cast(&I); + PredicatedContext PC(&I); + if (Value *V = SimplifyPredicatedFSubInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + SQ.getWithInstruction(Inst), PC)) + return replaceInstUsesWith(*Inst, V); + + return visitFSubGeneric(*Inst); +} + Instruction 
*InstCombiner::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -2129,6 +2147,14 @@ if (Instruction *X = foldVectorBinop(I)) return X; + return visitFSubGeneric(I); +} + +template +Instruction *InstCombiner::visitFSubGeneric(BinaryOpTy &I) { + MatchContextType MC(cast(&I)); + MatchContextBuilder MCBuilder(MC); + // Subtraction from -0.0 is the canonical form of fneg. // fsub -0.0, X ==> fneg X // fsub nsz 0.0, X ==> fneg nsz X @@ -2137,10 +2163,10 @@ // fsub -0.0, Denorm ==> +-0 // fneg Denorm ==> -Denorm Value *Op; - if (match(&I, m_FNeg(m_Value(Op)))) - return UnaryOperator::CreateFNegFMF(Op, &I); + if (MC.try_match(&I, m_FNeg(m_Value(Op)))) + return MCBuilder.CreateFNegFMF(Op, &I); - if (Instruction *X = foldFNegIntoConstant(I)) + if (Instruction *X = foldFNegIntoConstant(I)) return X; if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) @@ -2157,17 +2183,17 @@ // killed later. We still limit that particular transform with 'hasOneUse' // because an fneg is assumed better/cheaper than a generic fsub. 
if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { - if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { - Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); - return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); + if (MC.try_match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + Value *NewSub = MCBuilder.CreateFSubFMF(Builder, Y, X, &I); + return MCBuilder.CreateFAddFMF(Op0, NewSub, &I); } } // (-X) - Op1 --> -(X + Op1) if (I.hasNoSignedZeros() && !isa(Op0) && - match(Op0, m_OneUse(m_FNeg(m_Value(X))))) { - Value *FAdd = Builder.CreateFAddFMF(X, Op1, &I); - return UnaryOperator::CreateFNegFMF(FAdd, &I); + MC.try_match(Op0, m_OneUse(m_FNeg(m_Value(X))))) { + Value *FAdd = MCBuilder.CreateFAddFMF(Builder, X, Op1, &I); + return MCBuilder.CreateFNegFMF(FAdd, &I); } if (isa(Op0)) @@ -2178,22 +2204,22 @@ // X - C --> X + (-C) // But don't transform constant expressions because there's an inverse fold // for X + (-Y) --> X - Y. - if (match(Op1, m_Constant(C)) && !isa(Op1)) - return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(Op1, m_Constant(C)) && !isa(Op1)) + return MCBuilder.CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); // X - (-Y) --> X + Y - if (match(Op1, m_FNeg(m_Value(Y)))) - return BinaryOperator::CreateFAddFMF(Op0, Y, &I); + if (MC.try_match(Op1, m_FNeg(m_Value(Y)))) + return MCBuilder.CreateFAddFMF(Op0, Y, &I); // Similar to above, but look through a cast of the negated value: // X - (fptrunc(-Y)) --> X + fptrunc(Y) Type *Ty = I.getType(); - if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPTrunc(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPTrunc(Builder, Y, Ty), &I); // X - (fpext(-Y)) --> X + fpext(Y) - if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I); + 
if (MC.try_match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPExt(Builder, Y, Ty), &I); // Similar to above, but look through fmul/fdiv of the negated value: // Op0 - (-X * Y) --> Op0 + (X * Y) @@ -2211,32 +2237,35 @@ } // Handle special cases for FSub with selects feeding the operation - if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) - return replaceInstUsesWith(I, V); + if (auto * PlainBinOp = dyn_cast(&I)) + if (Value *V = SimplifySelectsFeedingBinaryOp(*PlainBinOp, Op0, Op1)) + return replaceInstUsesWith(I, V); if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { // (Y - X) - Y --> -X - if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) - return UnaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // Y - (X + Y) --> -X // Y - (Y + X) --> -X - if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) - return UnaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // (X * C) - X --> X * (C - 1.0) - if (match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { + if (MC.try_match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { Constant *CSubOne = ConstantExpr::getFSub(C, ConstantFP::get(Ty, 1.0)); - return BinaryOperator::CreateFMulFMF(Op1, CSubOne, &I); + return MCBuilder.CreateFMulFMF(Op1, CSubOne, &I); } // X - (X * C) --> X * (1.0 - C) - if (match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { + if (MC.try_match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { Constant *OneSubC = ConstantExpr::getFSub(ConstantFP::get(Ty, 1.0), C); - return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I); + return MCBuilder.CreateFMulFMF(Op0, OneSubC, &I); } - if (Instruction *F = factorizeFAddFSub(I, Builder)) - return F; + if (auto * PlainBinOp = dyn_cast(&I)) { + if (Instruction *F = factorizeFAddFSub(*PlainBinOp, Builder)) + return F; + } // 
TODO: This performs reassociative folds for FP ops. Some fraction of the // functionality has been subsumed by simple pattern matching here and in diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -39,6 +39,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsX86.h" @@ -1803,6 +1804,14 @@ return &CI; } + // Predicated instruction patterns + auto * VPInst = dyn_cast<VPIntrinsic>(&CI); + if (VPInst) { + auto * PredInst = cast<PredicatedInstruction>(VPInst); + auto Result = visitPredicatedInstruction(PredInst); + if (Result) return Result; + } + IntrinsicInst *II = dyn_cast<IntrinsicInst>(&CI); if (!II) return visitCallBase(CI); @@ -1867,7 +1876,8 @@ if (Changed) return II; } - // For vector result intrinsics, use the generic demanded vector support. + // For vector result intrinsics, use the generic demanded vector support to + // simplify any operands before moving on to the per-intrinsic rules. 
if (II->getType()->isVectorTy()) { auto VWidth = II->getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -30,6 +30,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Use.h" @@ -372,6 +373,8 @@ Value *OptimizePointerDifference( Value *LHS, Value *RHS, Type *Ty, bool isNUW); Instruction *visitSub(BinaryOperator &I); + template Instruction *visitFSubGeneric(BinaryOpTy &I); + Instruction *visitPredicatedFSub(PredicatedBinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); Instruction *visitFMul(BinaryOperator &I); @@ -448,6 +451,16 @@ Instruction *visitVAEndInst(VAEndInst &I); Instruction *visitFreeze(FreezeInst &I); + // Entry point to VPIntrinsic + Instruction *visitPredicatedInstruction(PredicatedInstruction * PI) { + switch (PI->getOpcode()) { + default: + return nullptr; + case Instruction::FSub: + return visitPredicatedFSub(cast(*PI)); + } + } + /// Specify what to return for unhandled instructions. 
Instruction *visitInstruction(Instruction &I) { return nullptr; } diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -863,6 +863,7 @@ case Attribute::InaccessibleMemOnly: case Attribute::InaccessibleMemOrArgMemOnly: case Attribute::JumpTable: + case Attribute::Mask: case Attribute::Naked: case Attribute::Nest: case Attribute::NoAlias: @@ -872,6 +873,7 @@ case Attribute::NoSync: case Attribute::None: case Attribute::NonNull: + case Attribute::Passthru: case Attribute::ReadNone: case Attribute::ReadOnly: case Attribute::Returned: @@ -883,6 +885,7 @@ case Attribute::SwiftError: case Attribute::SwiftSelf: case Attribute::WillReturn: + case Attribute::VectorLength: case Attribute::WriteOnly: case Attribute::ZExt: case Attribute::ImmArg: diff --git a/llvm/test/Bitcode/attributes.ll b/llvm/test/Bitcode/attributes.ll --- a/llvm/test/Bitcode/attributes.ll +++ b/llvm/test/Bitcode/attributes.ll @@ -374,6 +374,11 @@ ret void; } +; CHECK: define <8 x double> @f64(<8 x double> passthru %0, <8 x i1> mask %1, i32 vlen %2) { +define <8 x double> @f64(<8 x double> passthru, <8 x i1> mask, i32 vlen) { + ret <8 x double> undef +} + ; CHECK: attributes #0 = { noreturn } ; CHECK: attributes #1 = { nounwind } ; CHECK: attributes #2 = { readnone } diff --git a/llvm/test/CodeGen/AArch64/O0-pipeline.ll b/llvm/test/CodeGen/AArch64/O0-pipeline.ll --- a/llvm/test/CodeGen/AArch64/O0-pipeline.ll +++ b/llvm/test/CodeGen/AArch64/O0-pipeline.ll @@ -25,6 +25,7 @@ ; CHECK-NEXT: Lower constant intrinsics ; CHECK-NEXT: Remove unreachable blocks from the CFG ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. 
mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: AArch64 Stack Tagging diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll --- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll +++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll @@ -50,6 +50,7 @@ ; CHECK-NEXT: Constant Hoisting ; CHECK-NEXT: Partially inline calls to library functions ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Dominator Tree Construction diff --git a/llvm/test/CodeGen/ARM/O3-pipeline.ll b/llvm/test/CodeGen/ARM/O3-pipeline.ll --- a/llvm/test/CodeGen/ARM/O3-pipeline.ll +++ b/llvm/test/CodeGen/ARM/O3-pipeline.ll @@ -35,6 +35,7 @@ ; CHECK-NEXT: Constant Hoisting ; CHECK-NEXT: Partially inline calls to library functions ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. 
mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Dominator Tree Construction diff --git a/llvm/test/CodeGen/Generic/expand-vp.ll b/llvm/test/CodeGen/Generic/expand-vp.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/Generic/expand-vp.ll @@ -0,0 +1,245 @@ +; RUN: opt --expand-vec-pred -S < %s | FileCheck %s + +define void @test_vp_int(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.add}} +; CHECK-NOT: {{call.* @llvm.vp.sub}} +; CHECK-NOT: {{call.* @llvm.vp.mul}} +; CHECK-NOT: {{call.* @llvm.vp.sdiv}} +; CHECK-NOT: {{call.* @llvm.vp.srem}} +; CHECK-NOT: {{call.* @llvm.vp.udiv}} +; CHECK-NOT: {{call.* @llvm.vp.urem}} +; CHECK-NOT: {{call.* @llvm.vp.and}} +; CHECK-NOT: {{call.* @llvm.vp.or}} +; CHECK-NOT: {{call.* @llvm.vp.xor}} +; CHECK-NOT: {{call.* @llvm.vp.ashr}} +; CHECK-NOT: {{call.* @llvm.vp.lshr}} +; CHECK-NOT: {{call.* @llvm.vp.shl}} + %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r7 = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r8 = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r9 = call <8 x i32> @llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + 
%rA = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rB = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rC = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_constrainedfp(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, <8 x double> %f3, <8 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.reduce}} +; CHECK-NOT: {{call.* @llvm.vp.fadd}} +; CHECK-NOT: {{call.* @llvm.vp.fsub}} +; CHECK-NOT: {{call.* @llvm.vp.fmul}} +; CHECK-NOT: {{call.* @llvm.vp.frem}} +; CHECK-NOT: {{call.* @llvm.vp.fma}} +; CHECK-NOT: {{call.* @llvm.vp.fneg}} + %r0 = call <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r1 = call <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r2 = call <8 x double> @llvm.vp.fmul.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r3 = call <8 x double> @llvm.vp.fdiv.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r4 = call <8 x double> @llvm.vp.frem.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r5 = call <8 x double> @llvm.vp.fma.v8f64(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r6 = call <8 x double> @llvm.vp.fneg.v8f64(<8 x double> %f2, metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r7 = call <8 x double> @llvm.vp.minnum.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r8 = call <8 x double> @llvm.vp.maxnum.v8f64(<8 x 
double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_fpcast(<8 x double> %x, <8 x i64> %y, <8 x float> %z, <8 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.fptosi}} +; CHECK-NOT: {{call.* @llvm.vp.fptoui}} +; CHECK-NOT: {{call.* @llvm.vp.sitofp}} +; CHECK-NOT: {{call.* @llvm.vp.uitofp}} +; CHECK-NOT: {{call.* @llvm.vp.rint}} +; CHECK-NOT: {{call.* @llvm.vp.round}} +; CHECK-NOT: {{call.* @llvm.vp.nearbyint}} +; CHECK-NOT: {{call.* @llvm.vp.ceil}} +; CHECK-NOT: {{call.* @llvm.vp.floor}} +; CHECK-NOT: {{call.* @llvm.vp.trunc}} +; CHECK-NOT: {{call.* @llvm.vp.fptrunc}} +; CHECK-NOT: {{call.* @llvm.vp.fpext}} + %r0 = call <8 x i64> @llvm.vp.fptosi.v8i64v8f64(<8 x double> %x, metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r1 = call <8 x i64> @llvm.vp.fptoui.v8i64v8f64(<8 x double> %x, metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r2 = call <8 x double> @llvm.vp.sitofp.v8f64v8i64(<8 x i64> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r3 = call <8 x double> @llvm.vp.uitofp.v8f64v8i64(<8 x i64> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r4 = call <8 x double> @llvm.vp.rint.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r7 = call <8 x double> @llvm.vp.round.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rA = call <8 x double> @llvm.vp.nearbyint.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rB = call <8 x double> @llvm.vp.ceil.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rC = call <8 x double> @llvm.vp.floor.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rD = call <8 x double> 
@llvm.vp.trunc.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rE = call <8 x float> @llvm.vp.fptrunc.v8f32v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rF = call <8 x double> @llvm.vp.fpext.v8f64v8f32(<8 x float> %z, metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_fpfuncs(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.pow}} +; CHECK-NOT: {{call.* @llvm.vp.sqrt}} +; CHECK-NOT: {{call.* @llvm.vp.sin}} +; CHECK-NOT: {{call.* @llvm.vp.cos}} +; CHECK-NOT: {{call.* @llvm.vp.log}} +; CHECK-NOT: {{call.* @llvm.vp.exp\.}} +; CHECK-NOT: {{call.* @llvm.vp.exp2}} + %r0 = call <8 x double> @llvm.vp.pow.v8f64(<8 x double> %x, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r1 = call <8 x double> @llvm.vp.powi.v8f64(<8 x double> %x, i32 %n, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r2 = call <8 x double> @llvm.vp.sqrt.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r3 = call <8 x double> @llvm.vp.sin.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r4 = call <8 x double> @llvm.vp.cos.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r5 = call <8 x double> @llvm.vp.log.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r6 = call <8 x double> @llvm.vp.log10.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r7 = call <8 x double> @llvm.vp.log2.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r8 = call <8 x double> @llvm.vp.exp.v8f64(<8 x double> %x, metadata 
!"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r9 = call <8 x double> @llvm.vp.exp2.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_mem(<16 x i32*> %p0, <16 x i32>* %p1, <16 x i32> %i0, <16 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.load}} +; CHECK-NOT: {{call.* @llvm.vp.store}} +; CHECK-NOT: {{call.* @llvm.vp.gather}} +; CHECK-NOT: {{call.* @llvm.vp.scatter}} + call void @llvm.vp.store.v16i32.p0v16i32(<16 x i32> %i0, <16 x i32>* %p1, <16 x i1> %m, i32 %n) + call void @llvm.vp.scatter.v16i32.v16p0i32(<16 x i32> %i0 , <16 x i32*> %p0, <16 x i1> %m, i32 %n) + %l0 = call <16 x i32> @llvm.vp.load.v16i32.p0v16i32(<16 x i32>* %p1, <16 x i1> %m, i32 %n) + %l1 = call <16 x i32> @llvm.vp.gather.v16i32.v16p0i32(<16 x i32*> %p0, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_reduce_fp(<16 x float> %v, <16 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.reduce.fadd}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.fmul}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.fmin}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.fmax}} + %r0 = call float @llvm.vp.reduce.fadd.v16f32(float 0.0, <16 x float> %v, <16 x i1> %m, i32 %n) + %r1 = call float @llvm.vp.reduce.fmul.v16f32(float 42.0, <16 x float> %v, <16 x i1> %m, i32 %n) + %r2 = call float @llvm.vp.reduce.fmin.v16f32(<16 x float> %v, <16 x i1> %m, i32 %n) + %r3 = call float @llvm.vp.reduce.fmax.v16f32(<16 x float> %v, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_reduce_int(<16 x i32> %v, <16 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.reduce}} + %r0 = call i32 @llvm.vp.reduce.add.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r1 = call i32 @llvm.vp.reduce.mul.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r2 = call i32 @llvm.vp.reduce.and.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r3 = call i32 @llvm.vp.reduce.xor.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r4 = call i32 
@llvm.vp.reduce.or.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r5 = call i32 @llvm.vp.reduce.smin.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r6 = call i32 @llvm.vp.reduce.smax.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r7 = call i32 @llvm.vp.reduce.umin.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r8 = call i32 @llvm.vp.reduce.umax.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_shuffle(<16 x float> %v0, <16 x float> %v1, <16 x i1> %m, i32 %k, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.select}} +; CHECK-NOT: {{call.* @llvm.vp.compose}} +; no generic lowering available: {{call.* @llvm.vp.compress}} +; no generic lowering available: {{call.* @llvm.vp.expand}} +; CHECK-NOT: {{call.* @llvm.vp.vshift}} + %r0 = call <16 x float> @llvm.vp.select.v16f32(<16 x i1> %m, <16 x float> %v0, <16 x float> %v1, i32 %n) + %r1 = call <16 x float> @llvm.vp.compose.v16f32(<16 x float> %v0, <16 x float> %v1, i32 %k, i32 %n) + %r2 = call <16 x float> @llvm.vp.vshift.v16f32(<16 x float> %v0, i32 7, <16 x i1> %m, i32 %n) + %r3 = call <16 x float> @llvm.vp.compress.v16f32(<16 x float> %v0, <16 x i1> %m, i32 %n) + %r4 = call <16 x float> @llvm.vp.expand.v16f32(<16 x float> %v0, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_xcmp(<16 x i32> %i0, <16 x i32> %i1, <16 x float> %f0, <16 x float> %f1, <16 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.icmp}} +; CHECK-NOT: {{call.* @llvm.vp.fcmp}} + %r0 = call <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32> %i0, <16 x i32> %i1, i8 38, <16 x i1> %m, i32 %n) + %r1 = call <16 x i1> @llvm.vp.fcmp.v16f32(<16 x float> %f0, <16 x float> %f1, i8 10, <16 x i1> %m, i32 %n) + ret void +} + +; integer arith +declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 
x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +; bit arith +declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) + +; floating point arith +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fmul.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fdiv.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.frem.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fma.v8f64(<8 x double>, <8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fneg.v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.minnum.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.maxnum.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) + +; cast & conversions +declare <8 x i64> @llvm.vp.fptosi.v8i64v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) +declare <8 
x i64> @llvm.vp.fptoui.v8i64v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.sitofp.v8f64v8i64(<8 x i64>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.uitofp.v8f64v8i64(<8 x i64>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.rint.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.round.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.nearbyint.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.ceil.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.floor.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.trunc.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x float> @llvm.vp.fptrunc.v8f32v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fpext.v8f64v8f32(<8 x float> %x, metadata, <8 x i1> mask, i32 vlen) + +; math ops +declare <8 x double> @llvm.vp.pow.v8f64(<8 x double> %x, <8 x double> %y, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.powi.v8f64(<8 x double> %x, i32 %y, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.sqrt.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.sin.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.cos.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.log.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.log10.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.log2.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 
vlen) +declare <8 x double> @llvm.vp.exp.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.exp2.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) + +; memory +declare void @llvm.vp.store.v16i32.p0v16i32(<16 x i32>, <16 x i32>*, <16 x i1> mask, i32 vlen) +declare <16 x i32> @llvm.vp.load.v16i32.p0v16i32(<16 x i32>*, <16 x i1> mask, i32 vlen) +declare void @llvm.vp.scatter.v16i32.v16p0i32(<16 x i32>, <16 x i32*>, <16 x i1> mask, i32 vlen) +declare <16 x i32> @llvm.vp.gather.v16i32.v16p0i32(<16 x i32*>, <16 x i1> mask, i32 vlen) + +; reductions +declare float @llvm.vp.reduce.fadd.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmul.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmin.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmax.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.add.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.mul.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.and.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.xor.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.or.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.smax.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) +declare i32 @llvm.vp.reduce.smin.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) +declare i32 @llvm.vp.reduce.umax.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) +declare i32 @llvm.vp.reduce.umin.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + +; shuffles +declare <16 x float> @llvm.vp.select.v16f32(<16 x i1>, <16 x float>, <16 x float>, i32 vlen) +declare <16 x float> @llvm.vp.compose.v16f32(<16 x float>, <16 x float>, i32, i32 vlen) +declare <16 x float> @llvm.vp.vshift.v16f32(<16 x float>, i32, <16 x i1>, i32 vlen) +declare <16 x float> @llvm.vp.compress.v16f32(<16 x 
float>, <16 x i1>, i32 vlen) +declare <16 x float> @llvm.vp.expand.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) + +; icmp , fcmp +declare <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32>, <16 x i32>, i8, <16 x i1> mask, i32 vlen) +declare <16 x i1> @llvm.vp.fcmp.v16f32(<16 x float>, <16 x float>, i8, <16 x i1> mask, i32 vlen) diff --git a/llvm/test/CodeGen/X86/O0-pipeline.ll b/llvm/test/CodeGen/X86/O0-pipeline.ll --- a/llvm/test/CodeGen/X86/O0-pipeline.ll +++ b/llvm/test/CodeGen/X86/O0-pipeline.ll @@ -28,6 +28,7 @@ ; CHECK-NEXT: Lower constant intrinsics ; CHECK-NEXT: Remove unreachable blocks from the CFG ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Expand indirectbr instructions diff --git a/llvm/test/CodeGen/X86/O3-pipeline.ll b/llvm/test/CodeGen/X86/O3-pipeline.ll --- a/llvm/test/CodeGen/X86/O3-pipeline.ll +++ b/llvm/test/CodeGen/X86/O3-pipeline.ll @@ -47,6 +47,7 @@ ; CHECK-NEXT: Constant Hoisting ; CHECK-NEXT: Partially inline calls to library functions ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. 
mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Dominator Tree Construction diff --git a/llvm/test/Transforms/InstCombine/vp-fsub.ll b/llvm/test/Transforms/InstCombine/vp-fsub.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/vp-fsub.ll @@ -0,0 +1,45 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s + +; PR4374 + +define <4 x float> @test1_vp(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @test1_vp( +; + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + %t2 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> <float -0.0, float -0.0, float -0.0, float -0.0>, <4 x float> %t1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + +; Can't do anything with the test above because -0.0 - 0.0 = -0.0, but if we have nsz: +; -(X - Y) --> Y - X + +; TODO predicated FAdd folding +define <4 x float> @neg_sub_nsz_vp(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CH***-LABEL: @neg_sub_nsz_vp( +; + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + %t2 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> <float -0.0, float -0.0, float -0.0, float -0.0>, <4 x float> %t1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + +; With nsz: Z - (X - Y) --> Z + (Y - X) + +define <4 x float> @sub_sub_nsz_vp(<4 x float> %x, <4 x float> %y, <4 x float> %z, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @sub_sub_nsz_vp( +; CHECK-NEXT: %1 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %y, <4 x float> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) # +; CHECK-NEXT: %t2 = call nsz <4 x float> @llvm.vp.fadd.v4f32(<4 x float> %z, <4 x 
float> %1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) # +; CHECK-NEXT: ret <4 x float> %t2 + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + %t2 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %z, <4 x float> %t1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + + + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.vp.fadd.v4f32(<4 x float>, <4 x float>, metadata, metadata, <4 x i1> mask, i32 vlen) + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.vp.fsub.v4f32(<4 x float>, <4 x float>, metadata, metadata, <4 x i1> mask, i32 vlen) + +attributes #0 = { readnone } diff --git a/llvm/test/Transforms/InstSimplify/vp-fsub.ll b/llvm/test/Transforms/InstSimplify/vp-fsub.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/vp-fsub.ll @@ -0,0 +1,55 @@ +; RUN: opt < %s -instsimplify -S | FileCheck %s + +define <8 x double> @fsub_fadd_fold_vp_xy(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) { +; CHECK-LABEL: fsub_fadd_fold_vp_xy +; CHECK: ret <8 x double> %x + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %x, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + %res0 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + ret <8 x double> %res0 +} + +define <8 x double> @fsub_fadd_fold_vp_zw(<8 x double> %z, <8 x double> %w, <8 x i1> %m, i32 %len) { +; CHECK-LABEL: fsub_fadd_fold_vp_zw +; CHECK: ret <8 x double> %z + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %w, <8 x double> %z, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + %res1 = call reassoc nsz <8 x 
double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %w, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + ret <8 x double> %res1 +} + +; REQUIRES-CONSTRAINED-VP: define <8 x double> @fsub_fadd_fold_vp_yx_fpexcept(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) #0 { +; REQUIRES-CONSTRAINED-VP: ; *HECK-LABEL: fsub_fadd_fold_vp_yx +; REQUIRES-CONSTRAINED-VP: ; *HECK-NEXT: %tmp = +; REQUIRES-CONSTRAINED-VP: ; *HECK-NEXT: %res2 = +; REQUIRES-CONSTRAINED-VP: ; *HECK-NEXT: ret +; REQUIRES-CONSTRAINED-VP: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.strict", <8 x i1> %m, i32 %len) +; REQUIRES-CONSTRAINED-VP: %res2 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.strict", <8 x i1> %m, i32 %len) +; REQUIRES-CONSTRAINED-VP: ret <8 x double> %res2 +; REQUIRES-CONSTRAINED-VP: } + +define <8 x double> @fsub_fadd_fold_vp_yx_olen(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, i32 %otherLen) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx_olen +; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %otherLen) +; CHECK-NEXT: %res3 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) +; CHECK-NEXT: ret <8 x double> %res3 + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %otherLen) + %res3 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + ret <8 x double> %res3 +} + +define <8 x 
double> @fsub_fadd_fold_vp_yx_omask(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, <8 x i1> %othermask) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx_omask +; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) +; CHECK-NEXT: %res4 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %othermask, i32 %len) +; CHECK-NEXT: ret <8 x double> %res4 + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + %res4 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %othermask, i32 %len) + ret <8 x double> %res4 +} + +; Function Attrs: nounwind readnone +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) + +; Function Attrs: nounwind readnone +declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) + +attributes #0 = { strictfp } diff --git a/llvm/test/Verifier/vp-intrinsics-constrained.ll b/llvm/test/Verifier/vp-intrinsics-constrained.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Verifier/vp-intrinsics-constrained.ll @@ -0,0 +1,17 @@ +; RUN: not opt -S < %s 2>&1 | FileCheck %s +; CHECK: VP intrinsics only support the default fp environment for now (round.tonearest; fpexcept.ignore). +; CHECK: error: input module is broken! 
+ +define void @test_vp_strictfp(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, <8 x double> %f3, <8 x i1> %m, i32 %n) #0 { + %r0 = call <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.strict", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_rounding(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, <8 x double> %f3, <8 x i1> %m, i32 %n) #0 { + %r0 = call <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + ret void +} + +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) + +attributes #0 = { strictfp } diff --git a/llvm/test/Verifier/vp-intrinsics.ll b/llvm/test/Verifier/vp-intrinsics.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Verifier/vp-intrinsics.ll @@ -0,0 +1,190 @@ +; RUN: opt --verify %s + +define void @test_vp_int(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) { + %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r7 = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r8 = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r9 = call <8 x i32> @llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rA = call <8 x 
i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rB = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rC = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_constrainedfp(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, <8 x double> %f3, <8 x i1> %m, i32 %n) { + %r0 = call <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r1 = call <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r2 = call <8 x double> @llvm.vp.fmul.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r3 = call <8 x double> @llvm.vp.fdiv.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r4 = call <8 x double> @llvm.vp.frem.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r5 = call <8 x double> @llvm.vp.fma.v8f64(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r6 = call <8 x double> @llvm.vp.fneg.v8f64(<8 x double> %f2, metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r7 = call <8 x double> @llvm.vp.minnum.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r8 = call <8 x double> @llvm.vp.maxnum.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_fpcast(<8 x double> %x, <8 x i64> %y, <8 x float> %z, <8 x i1> %m, i32 %n) { + %r0 = call <8 x i64> 
@llvm.vp.fptosi.v8i64v8f64(<8 x double> %x, metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r1 = call <8 x i64> @llvm.vp.fptoui.v8i64v8f64(<8 x double> %x, metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r2 = call <8 x double> @llvm.vp.sitofp.v8f64v8i64(<8 x i64> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r3 = call <8 x double> @llvm.vp.uitofp.v8f64v8i64(<8 x i64> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r4 = call <8 x double> @llvm.vp.rint.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r7 = call <8 x double> @llvm.vp.round.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rA = call <8 x double> @llvm.vp.nearbyint.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rB = call <8 x double> @llvm.vp.ceil.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rC = call <8 x double> @llvm.vp.floor.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rD = call <8 x double> @llvm.vp.trunc.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rE = call <8 x float> @llvm.vp.fptrunc.v8f32v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %rF = call <8 x double> @llvm.vp.fpext.v8f64v8f32(<8 x float> %z, metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_fpfuncs(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %n) { + %r0 = call <8 x double> @llvm.vp.pow.v8f64(<8 x double> %x, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r1 = call <8 x double> @llvm.vp.powi.v8f64(<8 x double> %x, i32 %n, metadata !"round.tonearest", metadata 
!"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r2 = call <8 x double> @llvm.vp.sqrt.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r3 = call <8 x double> @llvm.vp.sin.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r4 = call <8 x double> @llvm.vp.cos.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r5 = call <8 x double> @llvm.vp.log.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r6 = call <8 x double> @llvm.vp.log10.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r7 = call <8 x double> @llvm.vp.log2.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r8 = call <8 x double> @llvm.vp.exp.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r9 = call <8 x double> @llvm.vp.exp2.v8f64(<8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_mem(<16 x i32*> %p0, <16 x i32>* %p1, <16 x i32> %i0, <16 x i1> %m, i32 %n) { + call void @llvm.vp.store.v16i32.p0v16i32(<16 x i32> %i0, <16 x i32>* %p1, <16 x i1> %m, i32 %n) + call void @llvm.vp.scatter.v16i32.v16p0i32(<16 x i32> %i0 , <16 x i32*> %p0, <16 x i1> %m, i32 %n) + %l0 = call <16 x i32> @llvm.vp.load.v16i32.p0v16i32(<16 x i32>* %p1, <16 x i1> %m, i32 %n) + %l1 = call <16 x i32> @llvm.vp.gather.v16i32.v16p0i32(<16 x i32*> %p0, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_reduce_fp(<16 x float> %v, <16 x i1> %m, i32 %n) { + %r0 = call float @llvm.vp.reduce.fadd.v16f32(float 0.0, <16 x float> %v, <16 x i1> %m, i32 %n) + %r1 = call float @llvm.vp.reduce.fmul.v16f32(float 42.0, <16 x float> %v, <16 x i1> %m, i32 %n) + %r2 = call float 
@llvm.vp.reduce.fmin.v16f32(<16 x float> %v, <16 x i1> %m, i32 %n) + %r3 = call float @llvm.vp.reduce.fmax.v16f32(<16 x float> %v, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_reduce_int(<16 x i32> %v, <16 x i1> %m, i32 %n) { + %r0 = call i32 @llvm.vp.reduce.add.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r1 = call i32 @llvm.vp.reduce.mul.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r2 = call i32 @llvm.vp.reduce.and.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r3 = call i32 @llvm.vp.reduce.xor.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r4 = call i32 @llvm.vp.reduce.or.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r5 = call i32 @llvm.vp.reduce.smin.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r6 = call i32 @llvm.vp.reduce.smax.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r7 = call i32 @llvm.vp.reduce.umin.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r8 = call i32 @llvm.vp.reduce.umax.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_shuffle(<16 x float> %v0, <16 x float> %v1, <16 x i1> %m, i32 %k, i32 %n) { + %r0 = call <16 x float> @llvm.vp.select.v16f32(<16 x i1> %m, <16 x float> %v0, <16 x float> %v1, i32 %n) + %r1 = call <16 x float> @llvm.vp.compose.v16f32(<16 x float> %v0, <16 x float> %v1, i32 %k, i32 %n) + %r2 = call <16 x float> @llvm.vp.vshift.v16f32(<16 x float> %v0, i32 %k, <16 x i1> %m, i32 %n) + %r3 = call <16 x float> @llvm.vp.compress.v16f32(<16 x float> %v0, <16 x i1> %m, i32 %n) + %r4 = call <16 x float> @llvm.vp.expand.v16f32(<16 x float> %v0, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_xcmp(<16 x i32> %i0, <16 x i32> %i1, <16 x float> %f0, <16 x float> %f1,<16 x i1> %m, i32 %n) { + %r0 = call <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32> %i0, <16 x i32> %i1, i8 38, <16 x i1> %m, i32 %n) + %r1 = call <16 x i1> @llvm.vp.fcmp.v16f32(<16 x float> %f0, <16 x float> %f1, i8 10, <16 x i1> %m, i32 %n) + ret void +} + +; integer arith +declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x 
i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +; bit arith +declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) + +; floating point arith +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fmul.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fdiv.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.frem.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fma.v8f64(<8 x double>, <8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fneg.v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.minnum.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 
vlen) +declare <8 x double> @llvm.vp.maxnum.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) + +; cast & conversions +declare <8 x i64> @llvm.vp.fptosi.v8i64v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) +declare <8 x i64> @llvm.vp.fptoui.v8i64v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.sitofp.v8f64v8i64(<8 x i64>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.uitofp.v8f64v8i64(<8 x i64>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.rint.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.round.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.nearbyint.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.ceil.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.floor.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.trunc.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x float> @llvm.vp.fptrunc.v8f32v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fpext.v8f64v8f32(<8 x float> %x, metadata, <8 x i1> mask, i32 vlen) + +; math ops +declare <8 x double> @llvm.vp.pow.v8f64(<8 x double> %x, <8 x double> %y, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.powi.v8f64(<8 x double> %x, i32 %y, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.sqrt.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.sin.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.cos.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.log.v8f64(<8 x double> 
%x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.log10.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.log2.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.exp.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.exp2.v8f64(<8 x double> %x, metadata, metadata, <8 x i1> mask, i32 vlen) + +; memory +declare void @llvm.vp.store.v16i32.p0v16i32(<16 x i32>, <16 x i32>*, <16 x i1> mask, i32 vlen) +declare <16 x i32> @llvm.vp.load.v16i32.p0v16i32(<16 x i32>*, <16 x i1> mask, i32 vlen) +declare void @llvm.vp.scatter.v16i32.v16p0i32(<16 x i32>, <16 x i32*>, <16 x i1> mask, i32 vlen) +declare <16 x i32> @llvm.vp.gather.v16i32.v16p0i32(<16 x i32*>, <16 x i1> mask, i32 vlen) + +; reductions +declare float @llvm.vp.reduce.fadd.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmul.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmin.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmax.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.add.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.mul.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.and.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.xor.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.or.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.smax.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) +declare i32 @llvm.vp.reduce.smin.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) +declare i32 @llvm.vp.reduce.umax.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) +declare i32 @llvm.vp.reduce.umin.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + +; shuffles +declare <16 x float> @llvm.vp.select.v16f32(<16 x i1>, <16 x float>, 
<16 x float>, i32 vlen) +declare <16 x float> @llvm.vp.compose.v16f32(<16 x float>, <16 x float>, i32, i32 vlen) +declare <16 x float> @llvm.vp.vshift.v16f32(<16 x float>, i32, <16 x i1>, i32 vlen) +declare <16 x float> @llvm.vp.compress.v16f32(<16 x float>, <16 x i1>, i32 vlen) +declare <16 x float> @llvm.vp.expand.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) + +; icmp , fcmp +declare <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32>, <16 x i32>, i8, <16 x i1> mask, i32 vlen) +declare <16 x i1> @llvm.vp.fcmp.v16f32(<16 x float>, <16 x float>, i8, <16 x i1> mask, i32 vlen) diff --git a/llvm/test/Verifier/vp_attributes.ll b/llvm/test/Verifier/vp_attributes.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Verifier/vp_attributes.ll @@ -0,0 +1,13 @@ +; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s + +declare void @a(<16 x i1> mask %a, <16 x i1> mask %b) +; CHECK: Cannot have multiple 'mask' parameters! + +declare void @b(<16 x i1> mask %a, i32 vlen %x, i32 vlen %y) +; CHECK: Cannot have multiple 'vlen' parameters! + +declare <16 x double> @c(<16 x double> passthru %a) +; CHECK: Cannot have 'passthru' parameter without 'mask' parameter! + +declare <16 x double> @d(<16 x double> passthru %a, <16 x i1> mask %M, <16 x double> passthru %b) +; CHECK: Cannot have multiple 'passthru' parameters! diff --git a/llvm/tools/llc/llc.cpp b/llvm/tools/llc/llc.cpp --- a/llvm/tools/llc/llc.cpp +++ b/llvm/tools/llc/llc.cpp @@ -315,6 +315,7 @@ initializeVectorization(*Registry); initializeScalarizeMaskedMemIntrinPass(*Registry); initializeExpandReductionsPass(*Registry); + initializeExpandVectorPredicationPass(*Registry); initializeHardwareLoopsPass(*Registry); // Initialize debugging passes. 
diff --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp --- a/llvm/tools/opt/opt.cpp +++ b/llvm/tools/opt/opt.cpp @@ -591,6 +591,7 @@ initializePostInlineEntryExitInstrumenterPass(Registry); initializeUnreachableBlockElimLegacyPassPass(Registry); initializeExpandReductionsPass(Registry); + initializeExpandVectorPredicationPass(Registry); initializeWasmEHPreparePass(Registry); initializeWriteBitcodePassPass(Registry); initializeHardwareLoopsPass(Registry); diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt --- a/llvm/unittests/IR/CMakeLists.txt +++ b/llvm/unittests/IR/CMakeLists.txt @@ -41,6 +41,7 @@ ValueTest.cpp VectorTypesTest.cpp VerifierTest.cpp + VPIntrinsicTest.cpp WaymarkTest.cpp ) diff --git a/llvm/unittests/IR/VPIntrinsicTest.cpp b/llvm/unittests/IR/VPIntrinsicTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/IR/VPIntrinsicTest.cpp @@ -0,0 +1,224 @@ +//===- VPIntrinsicTest.cpp - VPIntrinsic unit tests ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+/// Fixture shared by the VP intrinsic tests. It owns the LLVMContext and the
+/// parser diagnostic so that modules produced by the helpers below stay valid
+/// for the duration of a test.
+class VPIntrinsicTest : public testing::Test {
+protected:
+  /// Context that every module parsed by this fixture is created in.
+  LLVMContext C;
+  /// Receives the parser diagnostic if an assembly string fails to parse.
+  SMDiagnostic Err;
+
+  /// Parse a module that consists solely of declarations of llvm.vp.*
+  /// intrinsics (one fixed-width vector instantiation each).
+  ///
+  /// \returns the result of parseAssemblyString(); callers are expected to
+  /// check it before use (the diagnostic, if any, is stored in Err).
+  std::unique_ptr<Module> CreateVPDeclarationModule() {
+    return parseAssemblyString(
+" declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.fmul.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.fdiv.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.frem.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.fma.v8f64(<8 x double>, <8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.fneg.v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.minnum.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.maxnum.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) "
+" declare void @llvm.vp.store.v16i32.p0v16i32(<16 x i32>, <16 x i32>*, <16 x i1> mask, i32 vlen) "
+" declare void @llvm.vp.scatter.v16i32.v16p0i32(<16 x i32>, <16 x i32*>, <16 x i1> mask, i32 vlen) "
+" declare <16 x i32> @llvm.vp.load.v16i32.p0v16i32(<16 x i32>*, <16 x i1> mask, i32 vlen) "
+" declare <16 x i32> @llvm.vp.gather.v16i32.v16p0i32(<16 x i32*>, <16 x i1> mask, i32 vlen) "
+" declare float @llvm.vp.reduce.fadd.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) "
+" declare float @llvm.vp.reduce.fmul.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) "
+" declare float @llvm.vp.reduce.fmin.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) "
+" declare float @llvm.vp.reduce.fmax.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) "
+" declare i32 @llvm.vp.reduce.add.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) "
+" declare i32 @llvm.vp.reduce.mul.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) "
+" declare i32 @llvm.vp.reduce.and.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) "
+" declare i32 @llvm.vp.reduce.xor.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) "
+" declare i32 @llvm.vp.reduce.or.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) "
+" declare i32 @llvm.vp.reduce.smin.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) "
+" declare i32 @llvm.vp.reduce.smax.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) "
+" declare i32 @llvm.vp.reduce.umin.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) "
+" declare i32 @llvm.vp.reduce.umax.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) "
+" declare <16 x float> @llvm.vp.select.v16f32(<16 x i1>, <16 x float>, <16 x float>, i32 vlen) "
+" declare <16 x float> @llvm.vp.compose.v16f32(<16 x float>, <16 x float>, i32, i32 vlen) "
+" declare <16 x float> @llvm.vp.vshift.v16f32(<16 x float>, i32, <16 x i1>, i32 vlen) "
+" declare <16 x float> @llvm.vp.compress.v16f32(<16 x float>, <16 x i1>, i32 vlen) "
+" declare <16 x float> @llvm.vp.expand.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) "
+" declare <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32>, <16 x i32>, i8, <16 x i1> mask, i32 vlen) "
+" declare <16 x i1> @llvm.vp.fcmp.v16f32(<16 x float>, <16 x float>, i8, <16 x i1> mask, i32 vlen) "
+" declare <8 x i64> @llvm.vp.fptosi.v8i64v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x i64> @llvm.vp.fptoui.v8i64v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.sitofp.v8f64v8i64(<8 x i64>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.uitofp.v8f64v8i64(<8 x i64>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.rint.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.round.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.nearbyint.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.ceil.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.floor.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.trunc.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x float> @llvm.vp.fptrunc.v8f32v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.fpext.v8f64v8f32(<8 x float>, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.pow.v8f64(<8 x double>, <8 x double> %y, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.powi.v8f64(<8 x double>, i32 %y, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.sqrt.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.sin.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.cos.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.log.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.log10.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.log2.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.exp.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) "
+" declare <8 x double> @llvm.vp.exp2.v8f64(<8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) ",
+        Err, C);
+  }
+};
+
+/// Check that VPIntrinsic::canIgnoreVectorLengthParam() returns true
+/// if the vector length parameter does not mask off any lanes.
+TEST_F(VPIntrinsicTest, CanIgnoreVectorLength) {
+  // Parse into the fixture's context/diagnostic instead of shadowing them
+  // with test-local copies.
+  std::unique_ptr<Module> M = parseAssemblyString(
+"declare <256 x i64> @llvm.vp.mul.v256i64(<256 x i64>, <256 x i64>, <256 x i1>, i32)"
+"declare <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i1>, i32)"
+"declare i32 @llvm.vscale.i32()"
+"define void @test_static_vlen( "
+" <256 x i64> %i0, <vscale x 2 x i64> %si0,"
+" <256 x i64> %i1, <vscale x 2 x i64> %si1,"
+" <256 x i1> %m, <vscale x 2 x i1> %sm, i32 %vl) { "
+" %r0 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 %vl)"
+" %r1 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 256)"
+" %r2 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 0)"
+" %r3 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 7)"
+" %r4 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 123)"
+" %vs = call i32 @llvm.vscale.i32()"
+" %vs.i64 = mul i32 %vs, 2"
+" %r5 = call <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64> %si0, <vscale x 2 x i64> %si1, <vscale x 2 x i1> %sm, i32 %vs.i64)"
+" %r6 = call <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64> %si0, <vscale x 2 x i64> %si1, <vscale x 2 x i1> %sm, i32 99999)"
+" ret void "
+"}",
+      Err, C);
+  // Use gtest assertions rather than assert(): they stay active in NDEBUG
+  // builds and report the parser diagnostic instead of crashing.
+  ASSERT_TRUE(M != nullptr)
+      << "failed to parse test module: " << Err.getMessage().str();
+
+  Function *F = M->getFunction("test_static_vlen");
+  ASSERT_TRUE(F != nullptr);
+
+  // Expected result for %r0..%r6, in program order. Only %r1 (constant equal
+  // to the fixed vector width) and %r5 (vscale * 2 for an nxv2 vector) pass a
+  // length that covers every lane.
+  const bool Expected[] = {false, true, false, false, false, true, false};
+  const unsigned NumExpected = sizeof(Expected) / sizeof(Expected[0]);
+
+  unsigned NumSeen = 0;
+  for (auto &I : F->getEntryBlock()) {
+    // Skip everything that is not a VP intrinsic call (vscale, mul, ret).
+    auto *VPI = dyn_cast<VPIntrinsic>(&I);
+    if (!VPI)
+      continue;
+
+    ASSERT_LT(NumSeen, NumExpected);
+    ASSERT_EQ(Expected[NumSeen], VPI->canIgnoreVectorLengthParam());
+    ++NumSeen;
+  }
+  // Without this the test would pass vacuously if some (or all) of the calls
+  // were not recognised as VP intrinsics.
+  ASSERT_EQ(NumExpected, NumSeen);
+}
+
+/// Check that the argument returned by
+/// VPIntrinsic::GetParamPos(Intrinsic::ID) has the expected type.
+TEST_F(VPIntrinsicTest, GetParamPos) {
+  std::unique_ptr<Module> M = CreateVPDeclarationModule();
+  // gtest assertion instead of assert(): stays active in NDEBUG builds and
+  // surfaces the parser diagnostic.
+  ASSERT_TRUE(M != nullptr)
+      << "failed to parse VP declarations: " << Err.getMessage().str();
+
+  for (Function &F : *M) {
+    ASSERT_TRUE(F.isIntrinsic());
+    const Intrinsic::ID ID = F.getIntrinsicID();
+
+    // Each lookup yields an optional parameter index; an intrinsic that has
+    // no such parameter is simply skipped for that check.
+
+    // The mask must be a vector of i1.
+    const auto MaskParamPos = VPIntrinsic::GetMaskParamPos(ID);
+    if (MaskParamPos.hasValue()) {
+      Type *MaskParamType = F.getArg(MaskParamPos.getValue())->getType();
+      ASSERT_TRUE(MaskParamType->isVectorTy());
+      ASSERT_TRUE(MaskParamType->getVectorElementType()->isIntegerTy(1));
+    }
+
+    // The explicit vector length must be an i32.
+    const auto VecLenParamPos = VPIntrinsic::GetVectorLengthParamPos(ID);
+    if (VecLenParamPos.hasValue()) {
+      Type *VecLenParamType = F.getArg(VecLenParamPos.getValue())->getType();
+      ASSERT_TRUE(VecLenParamType->isIntegerTy(32));
+    }
+
+    // The memory operand must be a pointer or a vector of pointers.
+    const auto MemPtrParamPos = VPIntrinsic::GetMemoryPointerParamPos(ID);
+    if (MemPtrParamPos.hasValue()) {
+      Type *MemPtrParamType = F.getArg(MemPtrParamPos.getValue())->getType();
+      ASSERT_TRUE(MemPtrParamType->isPtrOrPtrVectorTy());
+    }
+
+    // Rounding mode and exception behavior are passed as metadata.
+    const auto RoundingParamPos = VPIntrinsic::GetRoundingModeParamPos(ID);
+    if (RoundingParamPos.hasValue()) {
+      Type *RoundingParamType =
+          F.getArg(RoundingParamPos.getValue())->getType();
+      ASSERT_TRUE(RoundingParamType->isMetadataTy());
+    }
+
+    const auto ExceptParamPos = VPIntrinsic::GetExceptionBehaviorParamPos(ID);
+    if (ExceptParamPos.hasValue()) {
+      Type *ExceptParamType = F.getArg(ExceptParamPos.getValue())->getType();
+      ASSERT_TRUE(ExceptParamType->isMetadataTy());
+    }
+  }
+}
+
+/// Check that going from Opcode to VP intrinsic and back results in the same
+/// Opcode.
+TEST_F(VPIntrinsicTest, OpcodeRoundTrip) {
+  // Collect the numeric opcode of every instruction listed in
+  // Instruction.def. SmallVector (already included by this file) avoids
+  // relying on a transitive <vector> include.
+  SmallVector<unsigned, 100> Opcodes;
+#define HANDLE_INST(OCNum, OCName, Class) Opcodes.push_back(OCNum);
+#include "llvm/IR/Instruction.def"
+
+  unsigned FullTripCounts = 0;
+  for (unsigned OC : Opcodes) {
+    const Intrinsic::ID VPID = VPIntrinsic::GetForOpcode(OC);
+    // no equivalent VP intrinsic available
+    if (VPID == Intrinsic::not_intrinsic)
+      continue;
+
+    const unsigned RoundTripOC = VPIntrinsic::GetFunctionalOpcodeForVP(VPID);
+    // no equivalent Opcode available
+    if (RoundTripOC == Instruction::Call)
+      continue;
+
+    ASSERT_EQ(RoundTripOC, OC);
+    ++FullTripCounts;
+  }
+  // At least one opcode has to survive the round trip, otherwise the loop
+  // above checked nothing.
+  ASSERT_NE(FullTripCounts, 0u);
+}
+
+} // end anonymous namespace
diff --git a/llvm/utils/TableGen/CodeGenIntrinsics.h b/llvm/utils/TableGen/CodeGenIntrinsics.h --- a/llvm/utils/TableGen/CodeGenIntrinsics.h +++ b/llvm/utils/TableGen/CodeGenIntrinsics.h @@ -123,6 +123,9 @@ /// True if the intrinsic is no-return. bool isNoReturn; + /// True if the intrinsic is no-sync. + bool isNoSync; + /// True if the intrinsic is will-return. bool isWillReturn; @@ -146,7 +149,10 @@ ReadOnly, WriteOnly, ReadNone, - ImmArg + ImmArg, + Mask, + VectorLength, + Passthru }; std::vector> ArgumentAttributes; diff --git a/llvm/utils/TableGen/CodeGenTarget.cpp b/llvm/utils/TableGen/CodeGenTarget.cpp --- a/llvm/utils/TableGen/CodeGenTarget.cpp +++ b/llvm/utils/TableGen/CodeGenTarget.cpp @@ -607,6 +607,7 @@ isCommutative = false; canThrow = false; isNoReturn = false; + isNoSync = false; isWillReturn = false; isCold = false; isNoDuplicate = false; @@ -726,8 +727,7 @@ // variants with iAny types; otherwise, if the intrinsic is not // overloaded, all the types can be specified directly.
assert(((!TyEl->isSubClassOf("LLVMExtendedType") && - !TyEl->isSubClassOf("LLVMTruncatedType") && - !TyEl->isSubClassOf("LLVMScalarOrSameVectorWidth")) || + !TyEl->isSubClassOf("LLVMTruncatedType")) || VT == MVT::iAny || VT == MVT::vAny) && "Expected iAny or vAny type"); } else @@ -772,6 +772,8 @@ isConvergent = true; else if (Property->getName() == "IntrNoReturn") isNoReturn = true; + else if (Property->getName() == "IntrNoSync") + isNoSync = true; else if (Property->getName() == "IntrWillReturn") isWillReturn = true; else if (Property->getName() == "IntrCold") @@ -789,6 +791,15 @@ } else if (Property->isSubClassOf("Returned")) { unsigned ArgNo = Property->getValueAsInt("ArgNo"); ArgumentAttributes.push_back(std::make_pair(ArgNo, Returned)); + } else if (Property->isSubClassOf("VectorLength")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, VectorLength)); + } else if (Property->isSubClassOf("Mask")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, Mask)); + } else if (Property->isSubClassOf("Passthru")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, Passthru)); } else if (Property->isSubClassOf("ReadOnly")) { unsigned ArgNo = Property->getValueAsInt("ArgNo"); ArgumentAttributes.push_back(std::make_pair(ArgNo, ReadOnly)); diff --git a/llvm/utils/TableGen/IntrinsicEmitter.cpp b/llvm/utils/TableGen/IntrinsicEmitter.cpp --- a/llvm/utils/TableGen/IntrinsicEmitter.cpp +++ b/llvm/utils/TableGen/IntrinsicEmitter.cpp @@ -579,6 +579,9 @@ if (L->isNoReturn != R->isNoReturn) return R->isNoReturn; + if (L->isNoSync != R->isNoSync) + return R->isNoSync; + if (L->isWillReturn != R->isWillReturn) return R->isWillReturn; @@ -684,6 +687,24 @@ OS << "Attribute::Returned"; addComma = true; break; + case CodeGenIntrinsic::VectorLength: + if (addComma) + OS << ","; + OS << "Attribute::VectorLength"; + 
addComma = true; + break; + case CodeGenIntrinsic::Mask: + if (addComma) + OS << ","; + OS << "Attribute::Mask"; + addComma = true; + break; + case CodeGenIntrinsic::Passthru: + if (addComma) + OS << ","; + OS << "Attribute::Passthru"; + addComma = true; + break; case CodeGenIntrinsic::ReadOnly: if (addComma) OS << ","; @@ -720,8 +741,8 @@ if (!intrinsic.canThrow || (intrinsic.ModRef != CodeGenIntrinsic::ReadWriteMem && !intrinsic.hasSideEffects) || - intrinsic.isNoReturn || intrinsic.isWillReturn || intrinsic.isCold || - intrinsic.isNoDuplicate || intrinsic.isConvergent || + intrinsic.isNoReturn || intrinsic.isNoSync || intrinsic.isWillReturn || + intrinsic.isCold || intrinsic.isNoDuplicate || intrinsic.isConvergent || intrinsic.isSpeculatable) { OS << " const Attribute::AttrKind Atts[] = {"; bool addComma = false; @@ -735,6 +756,12 @@ OS << "Attribute::NoReturn"; addComma = true; } + if (intrinsic.isNoSync) { + if (addComma) + OS << ","; + OS << "Attribute::NoSync"; + addComma = true; + } if (intrinsic.isWillReturn) { if (addComma) OS << ",";