diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -16,7 +16,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" -include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td" //===----------------------------------------------------------------------===// // ArmSVE dialect definition @@ -33,76 +32,6 @@ }]; } -//===----------------------------------------------------------------------===// -// ArmSVE type definitions -//===----------------------------------------------------------------------===// - -def ArmSVE_ScalableVectorType : DialectType()">, - "scalable vector type">, - BuildableType<"$_builder.getType()"> { - let description = [{ - `arm_sve.vector` represents vectors that will be processed by a scalable - vector architecture. - }]; -} - -class ArmSVE_Type : TypeDef { } - -def ScalableVectorType : ArmSVE_Type<"ScalableVector"> { - let mnemonic = "vector"; - - let summary = "Scalable vector type"; - - let description = [{ - A type representing scalable length SIMD vectors. Unlike fixed-length SIMD - vectors, whose size is constant and known at compile time, scalable - vectors' length is constant but determined by the specific hardware at - run time. - }]; - - let parameters = (ins - ArrayRefParameter<"int64_t", "Vector shape">:$shape, - "Type":$elementType - ); - - let printer = [{ - $_printer << "<"; - for (int64_t dim : getShape()) - $_printer << dim << 'x'; - $_printer << getElementType() << '>'; - }]; - - let parser = [{ - VectorType vector; - if ($_parser.parseType(vector)) - return Type(); - return get($_ctxt, vector.getShape(), vector.getElementType()); - }]; - - let extraClassDeclaration = [{ - bool hasStaticShape() const { - return llvm::none_of(getShape(), ShapedType::isDynamic); - } - int64_t getNumElements() const { - assert(hasStaticShape() && - "cannot get element count of dynamic shaped type"); - ArrayRef shape = getShape(); - int64_t num = 1; - for (auto dim : shape) - num *= dim; - return num; - } - }]; -} - -//===----------------------------------------------------------------------===// -// Additional LLVM type constraints -//===----------------------------------------------------------------------===// -def LLVMScalableVectorType : - Type()">, - "LLVM dialect scalable vector type">; - //===----------------------------------------------------------------------===// // ArmSVE op definitions //===----------------------------------------------------------------------===// @@ -110,16 +39,6 @@ class ArmSVE_Op traits = []> : Op {} -class ArmSVE_NonSVEIntrUnaryOverloadedOp traits =[]> : - LLVM_IntrOpBase overloadedResults=*/[0], - /*list overloadedOperands=*/[], // defined by result overload - /*list traits=*/traits, - /*int numResults=*/1>; - class ArmSVE_IntrBinaryOverloadedOp traits = []> : LLVM_IntrOpBase traits=*/traits, /*int numResults=*/1>; -class ScalableFOp traits = []> : - ArmSVE_Op])> { - let summary = op_description # " for scalable vectors of floats"; - let description = [{ - The `arm_sve.}] # mnemonic # [{` operations takes two scalable vectors and - returns one scalable vector with the result of the }] # op_description # [{. 
- }]; - let arguments = (ins - ScalableVectorOf<[AnyFloat]>:$src1, - ScalableVectorOf<[AnyFloat]>:$src2 - ); - let results = (outs ScalableVectorOf<[AnyFloat]>:$dst); - let assemblyFormat = - "$src1 `,` $src2 attr-dict `:` type($src1)"; -} - -class ScalableIOp traits = []> : - ArmSVE_Op])> { - let summary = op_description # " for scalable vectors of integers"; - let description = [{ - The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and - returns one scalable vector with the result of the }] # op_description # [{. - }]; - let arguments = (ins - ScalableVectorOf<[I8, I16, I32, I64]>:$src1, - ScalableVectorOf<[I8, I16, I32, I64]>:$src2 - ); - let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst); - let assemblyFormat = - "$src1 `,` $src2 attr-dict `:` type($src1)"; -} - class ScalableMaskedFOp traits = []> : ArmSVE_Op { - let summary = "Load vector scale size"; - let description = [{ - The vector_scale op returns the scale of the scalable vectors, a positive - integer value that is constant at runtime but unknown at compile time. - The scale of the vector indicates the multiplicity of the vectors and - vector operations. I.e.: an !arm_sve.vector<4xi32> is equivalent to - vector_scale consecutive vector<4xi32>; and an operation on an - !arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale - times, once on each <4xi32> segment of the scalable vector. The vector_scale - op can be used to calculate the step in vector-length agnostic (VLA) loops. - }]; - let results = (outs Index:$res); - let assemblyFormat = - "attr-dict `:` type($res)"; -} - -def ScalableLoadOp : ArmSVE_Op<"load">, - Arguments<(ins Arg:$base, Index:$index)>, - Results<(outs ScalableVectorOf<[AnyType]>:$result)> { - let summary = "Load scalable vector from memory"; - let description = [{ - Load a slice of memory into a scalable vector. - }]; - let extraClassDeclaration = [{ - MemRefType getMemRefType() { - return base().getType().cast(); - } - }]; - let assemblyFormat = "$base `[` $index `]` attr-dict `:` " - "type($result) `from` type($base)"; -} - -def ScalableStoreOp : ArmSVE_Op<"store">, - Arguments<(ins Arg:$base, Index:$index, - ScalableVectorOf<[AnyType]>:$value)> { - let summary = "Store scalable vector into memory"; - let description = [{ - Store a scalable vector on a slice of memory. 
- }]; - let extraClassDeclaration = [{ - MemRefType getMemRefType() { - return base().getType().cast(); - } - }]; - let assemblyFormat = "$value `,` $base `[` $index `]` attr-dict `:` " - "type($value) `to` type($base)"; -} - -def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>; - -def ScalableAddFOp : ScalableFOp<"addf", "addition", [Commutative]>; - -def ScalableSubIOp : ScalableIOp<"subi", "subtraction">; - -def ScalableSubFOp : ScalableFOp<"subf", "subtraction">; - -def ScalableMulIOp : ScalableIOp<"muli", "multiplication", [Commutative]>; - -def ScalableMulFOp : ScalableFOp<"mulf", "multiplication", [Commutative]>; - -def ScalableSDivIOp : ScalableIOp<"divi_signed", "signed division">; - -def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">; - -def ScalableDivFOp : ScalableFOp<"divf", "division">; - def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition", [Commutative]>; @@ -430,189 +245,56 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">; -//===----------------------------------------------------------------------===// -// ScalableCmpFOp -//===----------------------------------------------------------------------===// - -def ScalableCmpFOp : ArmSVE_Op<"cmpf", [NoSideEffect, SameTypeOperands, - TypesMatchWith<"result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">]> { - let summary = "floating-point comparison operation for scalable vectors"; - let description = [{ - The `arm_sve.cmpf` operation compares two scalable vectors of floating point - elements according to the float comparison rules and the predicate specified - by the respective attribute. The predicate defines the type of comparison: - (un)orderedness, (in)equality and signed less/greater than (or equal to) as - well as predicates that are always true or false. The result is a scalable - vector of i1 elements. Unlike `arm_sve.cmpi`, the operands are always - treated as signed. The u prefix indicates *unordered* comparison, not - unsigned comparison, so "une" means unordered not equal. For the sake of - readability by humans, custom assembly form for the operation uses a - string-typed attribute for the predicate. The value of this attribute - corresponds to lower-cased name of the predicate constant, e.g., "one" means - "ordered not equal". The string representation of the attribute is merely a - syntactic sugar and is converted to an integer attribute by the parser. 
- - Example: - - ```mlir - %r = arm_sve.cmpf oeq, %0, %1 : !arm_sve.vector<4xf32> - ``` - }]; - let arguments = (ins - Arith_CmpFPredicateAttr:$predicate, - ScalableVectorOf<[AnyFloat]>:$lhs, - ScalableVectorOf<[AnyFloat]>:$rhs // TODO: This should support a simple scalar - ); - let results = (outs ScalableVectorOf<[I1]>:$result); - - let builders = [ - OpBuilder<(ins "arith::CmpFPredicate":$predicate, "Value":$lhs, - "Value":$rhs), [{ - buildScalableCmpFOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static arith::CmpFPredicate getPredicateByName(StringRef name); - - arith::CmpFPredicate getPredicate() { - return (arith::CmpFPredicate) (*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; - - let verifier = [{ return success(); }]; - - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; -} - -//===----------------------------------------------------------------------===// -// ScalableCmpIOp -//===----------------------------------------------------------------------===// - -def ScalableCmpIOp : ArmSVE_Op<"cmpi", [NoSideEffect, SameTypeOperands, - TypesMatchWith<"result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">]> { - let summary = "integer comparison operation for scalable vectors"; - let description = [{ - The `arm_sve.cmpi` operation compares two scalable vectors of integer - elements according to the predicate specified by the respective attribute. - - The predicate defines the type of comparison: - - - equal (mnemonic: `"eq"`; integer value: `0`) - - not equal (mnemonic: `"ne"`; integer value: `1`) - - signed less than (mnemonic: `"slt"`; integer value: `2`) - - signed less than or equal (mnemonic: `"sle"`; integer value: `3`) - - signed greater than (mnemonic: `"sgt"`; integer value: `4`) - - signed greater than or equal (mnemonic: `"sge"`; integer value: `5`) - - unsigned less than (mnemonic: `"ult"`; integer value: `6`) - - unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`) - - unsigned greater than (mnemonic: `"ugt"`; integer value: `8`) - - unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`) - - Example: - - ```mlir - %r = arm_sve.cmpi uge, %0, %1 : !arm_sve.vector<4xi32> - ``` - }]; - - let arguments = (ins - Arith_CmpIPredicateAttr:$predicate, - ScalableVectorOf<[I8, I16, I32, I64]>:$lhs, - ScalableVectorOf<[I8, I16, I32, I64]>:$rhs - ); - let results = (outs ScalableVectorOf<[I1]>:$result); - - let builders = [ - OpBuilder<(ins "arith::CmpIPredicate":$predicate, "Value":$lhs, - "Value":$rhs), [{ - buildScalableCmpIOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static arith::CmpIPredicate getPredicateByName(StringRef name); - - arith::CmpIPredicate getPredicate() { - return (arith::CmpIPredicate) (*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; - - let verifier = [{ return success(); }]; - - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; -} - def UmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"ummla">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def SmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"smmla">, - Arguments<(ins LLVMScalableVectorType, 
LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def SdotIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sdot">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def UdotIntrOp : ArmSVE_IntrBinaryOverloadedOp<"udot">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedAddIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"add">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedAddFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fadd">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedMulIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"mul">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedMulFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fmul">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedSubIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sub">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedSubFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fsub">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedSDivIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sdiv">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedUDivIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"udiv">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedDivFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fdiv">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; - -def VectorScaleIntrOp: - ArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; #endif // ARMSVE_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h @@ -21,9 +21,6 @@ #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc" - #define GET_OP_CLASSES #include "mlir/Dialect/ArmSVE/ArmSVE.h.inc" diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td deleted file mode 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td +++ /dev/null @@ -1,53 +0,0 @@ -//===-- 
ArmSVEOpBase.td - Base op definitions for ArmSVE ---*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the base operation definition file for ArmSVE scalable vector types. -// -//===----------------------------------------------------------------------===// - -#ifndef ARMSVE_OP_BASE -#define ARMSVE_OP_BASE - -//===----------------------------------------------------------------------===// -// ArmSVE scalable vector type constraints -//===----------------------------------------------------------------------===// - -def IsScalableVectorTypePred : - CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">; - -class ScalableVectorOf allowedTypes> : - ContainerType, IsScalableVectorTypePred, - "$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()", - "scalable vector">; - -// Whether the number of elements of a scalable vector is from the given -// `allowedLengths` list -class IsScalableVectorOfLengthPred allowedLengths> : - And<[IsScalableVectorTypePred, - Or().getNumElements() == }] - # allowedlength>)>]>; - -// Any scalable vector where the number of elements is from the given -// `allowedLengths` list -class ScalableVectorOfLength allowedLengths> : Type< - IsScalableVectorOfLengthPred, - " of length " # !interleave(allowedLengths, "/"), - "::mlir::arm_sve::ScalableVectorType">; - -// Any scalable vector where the number of elements is from the given -// `allowedLengths` list and the type is from the given `allowedTypes` list -class ScalableVectorOfLengthAndType allowedLengths, - list allowedTypes> : Type< - And<[ScalableVectorOf.predicate, - ScalableVectorOfLength.predicate]>, - ScalableVectorOf.summary # - ScalableVectorOfLength.summary, - "::mlir::arm_sve::ScalableVectorType">; - -#endif // ARMSVE_OP_BASE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1732,7 +1732,9 @@ let arguments = (ins LLVM_Type, LLVM_Type, LLVM_Type); } -// +/// Create a call to vscale intrinsic. +def LLVM_vscale : LLVM_IntrOp<"vscale", [0], [], [], 1>; + // Atomic operations. // diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -446,10 +446,21 @@ /// Returns the element count of any LLVM-compatible vector type. llvm::ElementCount getVectorNumElements(Type type); +/// Returns whether a vector type is scalable or not. +bool getIsScalableVectorType(Type vectorType); + +/// Creates an LLVM dialect-compatible vector type with the given element type +/// and length. +Type getVectorType(Type elementType, unsigned numElements, bool isScalable); + /// Creates an LLVM dialect-compatible type with the given element type and /// length. Type getFixedVectorType(Type elementType, unsigned numElements); +/// Creates an LLVM dialect-compatible type with the given element type and +/// length. 
+Type getScalableVectorType(Type elementType, unsigned numElements);
+
 /// Returns the size of the given primitive LLVM dialect-compatible type
 /// (including vectors) in bits, for example, the size of i16 is 16 and
 /// the size of vector<4xi16> is 64. Returns 0 for non-primitive
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2369,4 +2369,27 @@
   let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
 }
 
+//===----------------------------------------------------------------------===//
+// VectorScaleOp
+//===----------------------------------------------------------------------===//
+
+def VectorScaleOp : Vector_Op<"vector_scale",
+    [NoSideEffect]> {
+  let summary = "Load vector scale size";
+  let description = [{
+    The vector_scale op returns the scale of the scalable vectors, a positive
+    integer value that is constant at runtime but unknown at compile-time.
+    The scale of the vector indicates the multiplicity of the vectors and
+    vector operations. For example, a vector<<4xi32>> is equivalent to
+    vector_scale consecutive vector<4xi32>; and an operation on a
+    vector<<4xi32>> is equivalent to performing that operation vector_scale
+    times, once on each <4xi32> segment of the scalable vector. The vector_scale
+    op can be used to calculate the step in vector-length agnostic (VLA) loops.
+  }];
+  let results = (outs Index:$res);
+  let assemblyFormat =
+    "attr-dict `:` type($res)";
+  let verifier = [{ return success(); }];
+}
+
 #endif // VECTOR_OPS
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -887,16 +887,21 @@
     Syntax:
 
     ```
-    vector-type ::= `vector` `<` static-dimension-list vector-element-type `>`
+    vector-type ::= fixed-length-vector | scalable-length-vector
+    fixed-length-vector ::= `vector` `<` static-dimension-list vector-element-type `>`
+    scalable-length-vector ::= `vector` `<<` static-dimension-list vector-element-type `>>`
     vector-element-type ::= float-type | integer-type | index-type
     static-dimension-list ::= (decimal-literal `x`)+
     ```
 
-    The vector type represents a SIMD style vector, used by target-specific
-    operation sets like AVX. While the most common use is for 1D vectors (e.g.
-    vector<16 x f32>) we also support multidimensional registers on targets that
-    support them (like TPUs).
+    The vector type represents a SIMD style vector, either fixed-length or
+    scalable length, used by target-specific operation sets like AVX or SVE.
+    Fixed-length vectors are represented by single angle brackets (< >), and
+    scalable-length vectors are represented by double angle brackets (<< >>).
+    While the most common use is for 1D vectors (e.g. vector<16 x f32>) we
+    also support multidimensional registers on targets that support them
+    (like TPUs).
 
     Vector shapes must be positive decimal integers.
 
@@ -908,19 +913,24 @@
     Examples:
 
    ```mlir
+    // A 2D fixed-length vector of 3x42 i32 elements.
     vector<3x42xi32>
+
+    // A 1D scalable-length vector that contains a multiple of 4 f32 elements.
+ vector<<4xf32>> ``` }]; let parameters = (ins ArrayRefParameter<"int64_t">:$shape, - "Type":$elementType + "Type":$elementType, + "bool":$isScalable ); let builders = [ TypeBuilderWithInferredContext<(ins - "ArrayRef":$shape, "Type":$elementType + "ArrayRef":$shape, "Type":$elementType, CArg<"bool", "false">:$isScalable ), [{ - return $_get(elementType.getContext(), shape, elementType); + return $_get(elementType.getContext(), shape, elementType, isScalable); }]> ]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -210,6 +210,14 @@ // Whether a type is a VectorType. def IsVectorTypePred : CPred<"$_self.isa<::mlir::VectorType>()">; +// Whether a type is a fixed-length VectorType. +def IsFixedVectorTypePred : CPred<[{$_self.isa<::mlir::VectorType>() && + !$_self.cast().getIsScalable()}]>; + +// Whether a type is a scalable VectorType. +def IsScalableVectorTypePred : CPred<[{$_self.isa<::mlir::VectorType>() && + $_self.cast().getIsScalable()}]>; + // Whether a type is a TensorType. def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">; @@ -598,6 +606,14 @@ ShapedContainerType; +class FixedVectorOf allowedTypes> : + ShapedContainerType; + +class ScalableVectorOf allowedTypes> : + ShapedContainerType; + // Whether the number of elements of a vector is from the given // `allowedRanks` list class IsVectorOfRankPred allowedRanks> : @@ -630,6 +646,24 @@ == }] # allowedlength>)>]>; +// Whether the number of elements of a fixed-length vector is from the given +// `allowedLengths` list +class IsFixedVectorOfLengthPred allowedLengths> : + And<[IsFixedVectorTypePred, + Or().getNumElements() + == }] + # allowedlength>)>]>; + +// Whether the number of elements of a scalable vector is from the given +// `allowedLengths` list +class IsScalableVectorOfLengthPred allowedLengths> : + And<[IsScalableVectorTypePred, + Or().getNumElements() + == }] + # allowedlength>)>]>; + // Any vector where the number of elements is from the given // `allowedLengths` list class VectorOfLength allowedLengths> : Type< @@ -637,6 +671,20 @@ " of length " # !interleave(allowedLengths, "/"), "::mlir::VectorType">; +// Any fixed-length vector where the number of elements is from the given +// `allowedLengths` list +class FixedVectorOfLength allowedLengths> : Type< + IsFixedVectorOfLengthPred, + " of length " # !interleave(allowedLengths, "/"), + "::mlir::VectorType">; + +// Any scalable vector where the number of elements is from the given +// `allowedLengths` list +class ScalableVectorOfLength allowedLengths> : Type< + IsScalableVectorOfLengthPred, + " of length " # !interleave(allowedLengths, "/"), + "::mlir::VectorType">; + // Any vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` // list @@ -647,8 +695,32 @@ VectorOf.summary # VectorOfLength.summary, "::mlir::VectorType">; +// Any fixed-length vector where the number of elements is from the given +// `allowedLengths` list and the type is from the given `allowedTypes` list +class FixedVectorOfLengthAndType allowedLengths, + list allowedTypes> : Type< + And<[FixedVectorOf.predicate, + FixedVectorOfLength.predicate]>, + FixedVectorOf.summary # + FixedVectorOfLength.summary, + "::mlir::VectorType">; + +// Any scalable vector where the number of elements is from the given +// `allowedLengths` list and the type is from the given `allowedTypes` list +class 
ScalableVectorOfLengthAndType allowedLengths, + list allowedTypes> : Type< + And<[ScalableVectorOf.predicate, + ScalableVectorOfLength.predicate]>, + ScalableVectorOf.summary # + ScalableVectorOfLength.summary, + "::mlir::VectorType">; + def AnyVector : VectorOf<[AnyType]>; +def AnyFixedVector : FixedVectorOf<[AnyType]>; + +def AnyScalableVector : ScalableVectorOf<[AnyType]>; + // Shaped types. def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped", diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -335,7 +335,8 @@ auto elementType = convertType(type.getElementType()); if (!elementType) return {}; - Type vectorType = VectorType::get(type.getShape().back(), elementType); + Type vectorType = VectorType::get(type.getShape().back(), elementType, + type.getIsScalable()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); auto shape = type.getShape(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -26,13 +26,15 @@ // Helper to reduce vector type by one rank at front. static VectorType reducedVectorTypeFront(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); - return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); + return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), + tp.getIsScalable()); } // Helper to reduce vector type by *all* but one rank at back. static VectorType reducedVectorTypeBack(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); - return VectorType::get(tp.getShape().take_back(), tp.getElementType()); + return VectorType::get(tp.getShape().take_back(), tp.getElementType(), + tp.getIsScalable()); } // Helper that picks the proper sequence for inserting. @@ -135,6 +137,10 @@ namespace { +/// Trivial Vector to LLVM conversions +using VectorScaleOpConversion = + OneToOneConvertToLLVMPattern; + /// Conversion pattern for a vector.bitcast. 
class VectorBitCastOpConversion : public ConvertOpToLLVMPattern { @@ -1040,7 +1046,7 @@ VectorExtractElementOpConversion, VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertElementOpConversion, VectorInsertOpConversion, VectorPrintOpConversion, - VectorTypeCastOpConversion, + VectorTypeCastOpConversion, VectorScaleOpConversion, VectorLoadStoreConversion, VectorLoadStoreConversion, diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -875,7 +875,8 @@ if (type.isa()) return UnrankedTensorType::get(i1Type); if (auto vectorType = type.dyn_cast()) - return VectorType::get(vectorType.getShape(), i1Type); + return VectorType::get(vectorType.getShape(), i1Type, + vectorType.getIsScalable()); return i1Type; } diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -25,12 +25,6 @@ #include "mlir/Dialect/ArmSVE/ArmSVEDialect.cpp.inc" static Type getI1SameShape(Type type); -static void buildScalableCmpIOp(OpBuilder &build, OperationState &result, - arith::CmpIPredicate predicate, Value lhs, - Value rhs); -static void buildScalableCmpFOp(OpBuilder &build, OperationState &result, - arith::CmpFPredicate predicate, Value lhs, - Value rhs); #define GET_OP_CLASSES #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" @@ -43,31 +37,6 @@ #define GET_OP_LIST #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc" - >(); -} - -//===----------------------------------------------------------------------===// -// ScalableVectorType -//===----------------------------------------------------------------------===// - -Type ArmSVEDialect::parseType(DialectAsmParser &parser) const { - llvm::SMLoc typeLoc = parser.getCurrentLocation(); - { - Type genType; - auto parseResult = generatedTypeParser(parser, "vector", genType); - if (parseResult.hasValue()) - return genType; - } - parser.emitError(typeLoc, "unknown type in ArmSVE dialect"); - return Type(); -} - -void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const { - if (failed(generatedTypePrinter(type, os))) - llvm_unreachable("unexpected 'arm_sve' type kind"); } //===----------------------------------------------------------------------===// @@ -77,30 +46,8 @@ // Return the scalable vector of the same shape and containing i1. 
static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto sVectorType = type.dyn_cast()) - return ScalableVectorType::get(type.getContext(), sVectorType.getShape(), - i1Type); + if (auto sVectorType = type.dyn_cast()) + return VectorType::get(sVectorType.getShape(), i1Type, + /* isScalable = */ true); return nullptr; } - -//===----------------------------------------------------------------------===// -// CmpFOp -//===----------------------------------------------------------------------===// - -static void buildScalableCmpFOp(OpBuilder &build, OperationState &result, - arith::CmpFPredicate predicate, Value lhs, - Value rhs) { - result.addOperands({lhs, rhs}); - result.types.push_back(getI1SameShape(lhs.getType())); - result.addAttribute(ScalableCmpFOp::getPredicateAttrName(), - build.getI64IntegerAttr(static_cast(predicate))); -} - -static void buildScalableCmpIOp(OpBuilder &build, OperationState &result, - arith::CmpIPredicate predicate, Value lhs, - Value rhs) { - result.addOperands({lhs, rhs}); - result.types.push_back(getI1SameShape(lhs.getType())); - result.addAttribute(ScalableCmpIOp::getPredicateAttrName(), - build.getI64IntegerAttr(static_cast(predicate))); -} diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -18,29 +18,6 @@ using namespace mlir; using namespace mlir::arm_sve; -// Extract an LLVM IR type from the LLVM IR dialect type. -static Type unwrap(Type type) { - if (!type) - return nullptr; - auto *mlirContext = type.getContext(); - if (!LLVM::isCompatibleType(type)) - emitError(UnknownLoc::get(mlirContext), - "conversion resulted in a non-LLVM type"); - return type; -} - -static Optional -convertScalableVectorTypeToLLVM(ScalableVectorType svType, - LLVMTypeConverter &converter) { - auto elementType = unwrap(converter.convertType(svType.getElementType())); - if (!elementType) - return {}; - - auto sVectorType = - LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back()); - return sVectorType; -} - template class ForwardOperands : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -70,22 +47,10 @@ } }; -static Optional addUnrealizedCast(OpBuilder &builder, - ScalableVectorType svType, - ValueRange inputs, Location loc) { - if (inputs.size() != 1 || - !inputs[0].getType().isa()) - return Value(); - return builder.create(loc, svType, inputs) - .getResult(0); -} - using SdotOpLowering = OneToOneConvertToLLVMPattern; using SmmlaOpLowering = OneToOneConvertToLLVMPattern; using UdotOpLowering = OneToOneConvertToLLVMPattern; using UmmlaOpLowering = OneToOneConvertToLLVMPattern; -using VectorScaleOpLowering = - OneToOneConvertToLLVMPattern; using ScalableMaskedAddIOpLowering = OneToOneConvertToLLVMPattern; @@ -114,136 +79,10 @@ OneToOneConvertToLLVMPattern; -// Load operation is lowered to code that obtains a pointer to the indexed -// element and loads from it. 
-struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto type = loadOp.getMemRefType(); - if (!isConvertibleAndHasIdentityMaps(type)) - return failure(); - - LLVMTypeConverter converter(loadOp.getContext()); - - auto resultType = loadOp.result().getType(); - LLVM::LLVMPointerType llvmDataTypePtr; - if (resultType.isa()) { - llvmDataTypePtr = - LLVM::LLVMPointerType::get(resultType.cast()); - } else if (resultType.isa()) { - llvmDataTypePtr = LLVM::LLVMPointerType::get( - convertScalableVectorTypeToLLVM(resultType.cast(), - converter) - .getValue()); - } - Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(), - adaptor.index(), rewriter); - Value bitCastedPtr = rewriter.create( - loadOp.getLoc(), llvmDataTypePtr, dataPtr); - rewriter.replaceOpWithNewOp(loadOp, bitCastedPtr); - return success(); - } -}; - -// Store operation is lowered to code that obtains a pointer to the indexed -// element, and stores the given value to it. -struct ScalableStoreOpLowering - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto type = storeOp.getMemRefType(); - if (!isConvertibleAndHasIdentityMaps(type)) - return failure(); - - LLVMTypeConverter converter(storeOp.getContext()); - - auto resultType = storeOp.value().getType(); - LLVM::LLVMPointerType llvmDataTypePtr; - if (resultType.isa()) { - llvmDataTypePtr = - LLVM::LLVMPointerType::get(resultType.cast()); - } else if (resultType.isa()) { - llvmDataTypePtr = LLVM::LLVMPointerType::get( - convertScalableVectorTypeToLLVM(resultType.cast(), - converter) - .getValue()); - } - Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(), - adaptor.index(), rewriter); - Value bitCastedPtr = rewriter.create( - storeOp.getLoc(), llvmDataTypePtr, dataPtr); - rewriter.replaceOpWithNewOp(storeOp, adaptor.value(), - bitCastedPtr); - return success(); - } -}; - -static void -populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns) { - // clang-format off - patterns.add, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern - >(converter); - // clang-format on -} - -static void -configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) { - // clang-format off - target.addIllegalOp(); - // clang-format on -} - -static void -populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns) { - // clang-format off - patterns.add, - OneToOneConvertToLLVMPattern - >(converter); - // clang-format on -} - -static void -configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) { - // clang-format off - target.addIllegalOp(); - // clang-format on -} - /// Populate the given list with patterns that convert from ArmSVE to LLVM. 
void mlir::populateArmSVELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // Populate conversion patterns - // Remove any ArmSVE-specific types from function signatures and results. - populateFuncOpTypeConversionPattern(patterns, converter); - converter.addConversion([&converter](ScalableVectorType svType) { - return convertScalableVectorTypeToLLVM(svType, converter); - }); - converter.addSourceMaterialization(addUnrealizedCast); // clang-format off patterns.add, @@ -254,7 +93,6 @@ SmmlaOpLowering, UdotOpLowering, UmmlaOpLowering, - VectorScaleOpLowering, ScalableMaskedAddIOpLowering, ScalableMaskedAddFOpLowering, ScalableMaskedSubIOpLowering, @@ -264,11 +102,7 @@ ScalableMaskedSDivIOpLowering, ScalableMaskedUDivIOpLowering, ScalableMaskedDivFOpLowering>(converter); - patterns.add(converter); // clang-format on - populateBasicSVEArithmeticExportPatterns(converter, patterns); - populateSVEMaskGenerationExportPatterns(converter, patterns); } void mlir::configureArmSVELegalizeForExportTarget( @@ -278,7 +112,6 @@ SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp, - VectorScaleIntrOp, ScalableMaskedAddIIntrOp, ScalableMaskedAddFIntrOp, ScalableMaskedSubIIntrOp, @@ -292,7 +125,6 @@ SmmlaOp, UdotOp, UmmlaOp, - VectorScaleOp, ScalableMaskedAddIOp, ScalableMaskedAddFOp, ScalableMaskedSubIOp, @@ -301,25 +133,6 @@ ScalableMaskedMulFOp, ScalableMaskedSDivIOp, ScalableMaskedUDivIOp, - ScalableMaskedDivFOp, - ScalableLoadOp, - ScalableStoreOp>(); + ScalableMaskedDivFOp>(); // clang-format on - auto hasScalableVectorType = [](TypeRange types) { - for (Type type : types) - if (type.isa()) - return true; - return false; - }; - target.addDynamicallyLegalOp([hasScalableVectorType](FuncOp op) { - return !hasScalableVectorType(op.getType().getInputs()) && - !hasScalableVectorType(op.getType().getResults()); - }); - target.addDynamicallyLegalOp( - [hasScalableVectorType](Operation *op) { - return !hasScalableVectorType(op->getOperandTypes()) && - !hasScalableVectorType(op->getResultTypes()); - }); - configureBasicSVEArithmeticLegalizations(target); - configureSVEMaskGenerationLegalizations(target); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -155,12 +155,14 @@ return parser.emitError(trailingTypeLoc, "expected LLVM dialect-compatible type"); if (LLVM::isCompatibleVectorType(type)) { - if (type.isa()) { - resultType = LLVM::LLVMScalableVectorType::get( - resultType, LLVM::getVectorNumElements(type).getKnownMinValue()); + if (LLVM::getIsScalableVectorType(type)) { + resultType = LLVM::getVectorType( + resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), + /* isScalable = */ true); } else { - resultType = LLVM::getFixedVectorType( - resultType, LLVM::getVectorNumElements(type).getFixedValue()); + resultType = LLVM::getVectorType( + resultType, LLVM::getVectorNumElements(type).getFixedValue(), + /* isScalable = */ false); } } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -548,7 +548,12 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) { return llvm::TypeSwitch(type) - .Case([](auto ty) { + .Case([](auto ty) { + if (ty.getIsScalable()) + return llvm::ElementCount::getScalable(ty.getNumElements()); + return 
llvm::ElementCount::getFixed(ty.getNumElements()); + }) + .Case([](auto ty) { return llvm::ElementCount::getFixed(ty.getNumElements()); }) .Case([](LLVMScalableVectorType ty) { @@ -559,6 +564,31 @@ }); } +bool mlir::LLVM::getIsScalableVectorType(Type vectorType) { + assert( + (vectorType + .isa()) && + "expected LLVM-compatible vector type"); + return !vectorType.isa() && + (vectorType.isa() || + vectorType.cast().getIsScalable()); +} + +Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements, + bool isScalable) { + bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType); + bool useBuiltIn = VectorType::isValidElementType(elementType); + (void)useBuiltIn; + assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type " + "to be either builtin or LLVM dialect type"); + if (useLLVM) { + if (isScalable) + return LLVMScalableVectorType::get(elementType, numElements); + return LLVMFixedVectorType::get(elementType, numElements); + } + return VectorType::get(numElements, elementType, isScalable); +} + Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) { bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType); bool useBuiltIn = VectorType::isValidElementType(elementType); @@ -570,6 +600,18 @@ return VectorType::get(numElements, elementType); } +Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) { + bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType); + bool useBuiltIn = VectorType::isValidElementType(elementType); + (void)useBuiltIn; + assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector " + "type to be either builtin or LLVM dialect " + "type"); + if (useLLVM) + return LLVMScalableVectorType::get(elementType, numElements); + return VectorType::get(numElements, elementType, /* isScalable =*/true); +} + llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { assert(isCompatibleType(type) && "expected a type compatible with the LLVM dialect"); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -537,7 +537,8 @@ if (type.isa()) return UnrankedTensorType::get(i1Type); if (auto vectorType = type.dyn_cast()) - return VectorType::get(vectorType.getShape(), i1Type); + return VectorType::get(vectorType.getShape(), i1Type, + vectorType.getIsScalable()); return i1Type; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1934,10 +1934,14 @@ }) .Case([&](VectorType vectorTy) { os << "vector<"; + if (vectorTy.getIsScalable()) + os << "<"; for (int64_t dim : vectorTy.getShape()) os << dim << 'x'; printType(vectorTy.getElementType()); os << '>'; + if (vectorTy.getIsScalable()) + os << ">"; }) .Case([&](RankedTensorType tensorTy) { os << "tensor<"; diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -294,7 +294,8 @@ return RankedTensorType::get(shape, elementType); if (isa()) - return VectorType::get(shape, elementType); + return VectorType::get(shape, elementType, + cast().getIsScalable()); llvm_unreachable("Unhandled ShapedType clone case"); } @@ -317,7 +318,8 @@ return RankedTensorType::get(shape, getElementType()); if (isa()) - return VectorType::get(shape, getElementType()); + return VectorType::get(shape, getElementType(), + 
cast().getIsScalable()); llvm_unreachable("Unhandled ShapedType clone case"); } @@ -340,7 +342,8 @@ } if (isa()) - return VectorType::get(getShape(), elementType); + return VectorType::get(getShape(), elementType, + cast().getIsScalable()); llvm_unreachable("Unhandled ShapedType clone hit"); } @@ -440,7 +443,8 @@ //===----------------------------------------------------------------------===// LogicalResult VectorType::verify(function_ref emitError, - ArrayRef shape, Type elementType) { + ArrayRef shape, Type elementType, + bool isScalable) { if (shape.empty()) return emitError() << "vector types must have at least one dimension"; diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -453,6 +453,10 @@ if (parseToken(Token::less, "expected '<' in vector type")) return nullptr; + bool isScalable = false; + if (consumeIf(Token::less)) + isScalable = true; + SmallVector dimensions; if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) return nullptr; @@ -468,11 +472,14 @@ auto elementType = parseType(); if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) return nullptr; + if (isScalable && + parseToken(Token::greater, "expected extra '>' in scalable vector type")) + return nullptr; if (!VectorType::isValidElementType(elementType)) return emitError(typeLoc, "vector elements must be int/index/float type"), nullptr; - return VectorType::get(dimensions, elementType); + return VectorType::get(dimensions, elementType, isScalable); } /// Parse a dimension list of a tensor or memref type. This populates the diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -242,10 +242,14 @@ if (auto *arrayTy = dyn_cast(llvmType)) { elementType = arrayTy->getElementType(); numElements = arrayTy->getNumElements(); + } else if (auto fVectorTy = dyn_cast(llvmType)) { + elementType = fVectorTy->getElementType(); + numElements = fVectorTy->getNumElements(); + } else if (auto sVectorTy = dyn_cast(llvmType)) { + elementType = sVectorTy->getElementType(); + numElements = sVectorTy->getMinNumElements(); } else { - auto *vectorTy = cast(llvmType); - elementType = vectorTy->getElementType(); - numElements = vectorTy->getNumElements(); + llvm_unreachable("unrecognized constant vector type"); } // Splat value is a scalar. Extract it only if the element type is not // another sequence type. 
The recursion terminates because each step removes diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp --- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp @@ -137,6 +137,9 @@ llvm::Type *translate(VectorType type) { assert(LLVM::isCompatibleVectorType(type) && "expected compatible with LLVM vector type"); + if (type.getIsScalable()) + return llvm::ScalableVectorType::get(translateType(type.getElementType()), + type.getNumElements()); return llvm::FixedVectorType::get(translateType(type.getElementType()), type.getNumElements()); } diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir --- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -1,168 +1,168 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-std-to-llvm | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-std-to-llvm -reconcile-unrealized-casts | mlir-opt | FileCheck %s -func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi32> { +func @arm_sve_sdot(%a: vector<<16xi8>>, + %b: vector<<16xi8>>, + %c: vector<<4xi32>>) + -> vector<<4xi32>> { // CHECK: arm_sve.intr.sdot %0 = arm_sve.sdot %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<<16xi8>> to vector<<4xi32>> + return %0 : vector<<4xi32>> } -func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi32> { +func @arm_sve_smmla(%a: vector<<16xi8>>, + %b: vector<<16xi8>>, + %c: vector<<4xi32>>) + -> vector<<4xi32>> { // CHECK: arm_sve.intr.smmla %0 = arm_sve.smmla %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<<16xi8>> to vector<<4xi32>> + return %0 : vector<<4xi32>> } -func @arm_sve_udot(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi32> { +func @arm_sve_udot(%a: vector<<16xi8>>, + %b: vector<<16xi8>>, + %c: vector<<4xi32>>) + -> vector<<4xi32>> { // CHECK: arm_sve.intr.udot %0 = arm_sve.udot %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<<16xi8>> to vector<<4xi32>> + return %0 : vector<<4xi32>> } -func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi32> { +func @arm_sve_ummla(%a: vector<<16xi8>>, + %b: vector<<16xi8>>, + %c: vector<<4xi32>>) + -> vector<<4xi32>> { // CHECK: arm_sve.intr.ummla %0 = arm_sve.ummla %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<<16xi8>> to vector<<4xi32>> + return %0 : vector<<4xi32>> } -func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>, - %c: !arm_sve.vector<4xi32>, - %d: !arm_sve.vector<4xi32>, - %e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: llvm.mul {{.*}}: !llvm.vec - %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32> - // CHECK: llvm.add {{.*}}: !llvm.vec - %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32> - // CHECK: llvm.sub {{.*}}: !llvm.vec - %2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32> - // CHECK: llvm.sdiv {{.*}}: !llvm.vec - %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32> - // CHECK: llvm.udiv 
{{.*}}: !llvm.vec - %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32> - return %4 : !arm_sve.vector<4xi32> +func @arm_sve_arithi(%a: vector<<4xi32>>, + %b: vector<<4xi32>>, + %c: vector<<4xi32>>, + %d: vector<<4xi32>>, + %e: vector<<4xi32>>) -> vector<<4xi32>> { + // CHECK: llvm.mul {{.*}}: vector<<4xi32>> + %0 = arith.muli %a, %b : vector<<4xi32>> + // CHECK: llvm.sub {{.*}}: vector<<4xi32>> + %1 = arith.subi %0, %c : vector<<4xi32>> + // CHECK: llvm.sdiv {{.*}}: vector<<4xi32>> + %2 = arith.divsi %1, %d : vector<<4xi32>> + // CHECK: llvm.udiv {{.*}}: vector<<4xi32>> + %3 = arith.divui %1, %e : vector<<4xi32>> + // CHECK: llvm.add {{.*}}: vector<<4xi32>> + %4 = arith.addi %2, %3 : vector<<4xi32>> + return %4 : vector<<4xi32>> } -func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>, - %c: !arm_sve.vector<4xf32>, - %d: !arm_sve.vector<4xf32>, - %e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> { - // CHECK: llvm.fmul {{.*}}: !llvm.vec - %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32> - // CHECK: llvm.fadd {{.*}}: !llvm.vec - %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32> - // CHECK: llvm.fsub {{.*}}: !llvm.vec - %2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32> - // CHECK: llvm.fdiv {{.*}}: !llvm.vec - %3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32> - return %3 : !arm_sve.vector<4xf32> +func @arm_sve_arithf(%a: vector<<4xf32>>, + %b: vector<<4xf32>>, + %c: vector<<4xf32>>, + %d: vector<<4xf32>>, + %e: vector<<4xf32>>) -> vector<<4xf32>> { + // CHECK: llvm.fmul {{.*}}: vector<<4xf32>> + %0 = arith.mulf %a, %b : vector<<4xf32>> + // CHECK: llvm.fadd {{.*}}: vector<<4xf32>> + %1 = arith.addf %0, %c : vector<<4xf32>> + // CHECK: llvm.fsub {{.*}}: vector<<4xf32>> + %2 = arith.subf %1, %d : vector<<4xf32>> + // CHECK: llvm.fdiv {{.*}}: vector<<4xf32>> + %3 = arith.divf %2, %e : vector<<4xf32>> + return %3 : vector<<4xf32>> } -func @arm_sve_arithi_masked(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>, - %c: !arm_sve.vector<4xi32>, - %d: !arm_sve.vector<4xi32>, - %e: !arm_sve.vector<4xi32>, - %mask: !arm_sve.vector<4xi1> - ) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.intr.add{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %0 = arm_sve.masked.addi %mask, %a, %b : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.intr.sub{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %1 = arm_sve.masked.subi %mask, %0, %c : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.intr.mul{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %2 = arm_sve.masked.muli %mask, %1, %d : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.intr.sdiv{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.intr.udiv{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - return %4 : !arm_sve.vector<4xi32> +func @arm_sve_arithi_masked(%a: vector<<4xi32>>, + %b: vector<<4xi32>>, + %c: vector<<4xi32>>, + %d: vector<<4xi32>>, + %e: vector<<4xi32>>, + %mask: vector<<4xi1>> + ) -> vector<<4xi32>> { + // CHECK: arm_sve.intr.add{{.*}}: (vector<<4xi1>>, vector<<4xi32>>, vector<<4xi32>>) -> vector<<4xi32>> + %0 = arm_sve.masked.addi %mask, %a, %b : vector<<4xi1>>, + vector<<4xi32>> + // CHECK: arm_sve.intr.sub{{.*}}: (vector<<4xi1>>, vector<<4xi32>>, vector<<4xi32>>) -> vector<<4xi32>> + %1 = 
arm_sve.masked.subi %mask, %0, %c : vector<<4xi1>>, + vector<<4xi32>> + // CHECK: arm_sve.intr.mul{{.*}}: (vector<<4xi1>>, vector<<4xi32>>, vector<<4xi32>>) -> vector<<4xi32>> + %2 = arm_sve.masked.muli %mask, %1, %d : vector<<4xi1>>, + vector<<4xi32>> + // CHECK: arm_sve.intr.sdiv{{.*}}: (vector<<4xi1>>, vector<<4xi32>>, vector<<4xi32>>) -> vector<<4xi32>> + %3 = arm_sve.masked.divi_signed %mask, %2, %e : vector<<4xi1>>, + vector<<4xi32>> + // CHECK: arm_sve.intr.udiv{{.*}}: (vector<<4xi1>>, vector<<4xi32>>, vector<<4xi32>>) -> vector<<4xi32>> + %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : vector<<4xi1>>, + vector<<4xi32>> + return %4 : vector<<4xi32>> } -func @arm_sve_arithf_masked(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>, - %c: !arm_sve.vector<4xf32>, - %d: !arm_sve.vector<4xf32>, - %e: !arm_sve.vector<4xf32>, - %mask: !arm_sve.vector<4xi1> - ) -> !arm_sve.vector<4xf32> { - // CHECK: arm_sve.intr.fadd{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %0 = arm_sve.masked.addf %mask, %a, %b : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.intr.fsub{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %1 = arm_sve.masked.subf %mask, %0, %c : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.intr.fmul{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %2 = arm_sve.masked.mulf %mask, %1, %d : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.intr.fdiv{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - return %3 : !arm_sve.vector<4xf32> +func @arm_sve_arithf_masked(%a: vector<<4xf32>>, + %b: vector<<4xf32>>, + %c: vector<<4xf32>>, + %d: vector<<4xf32>>, + %e: vector<<4xf32>>, + %mask: vector<<4xi1>> + ) -> vector<<4xf32>> { + // CHECK: arm_sve.intr.fadd{{.*}}: (vector<<4xi1>>, vector<<4xf32>>, vector<<4xf32>>) -> vector<<4xf32>> + %0 = arm_sve.masked.addf %mask, %a, %b : vector<<4xi1>>, + vector<<4xf32>> + // CHECK: arm_sve.intr.fsub{{.*}}: (vector<<4xi1>>, vector<<4xf32>>, vector<<4xf32>>) -> vector<<4xf32>> + %1 = arm_sve.masked.subf %mask, %0, %c : vector<<4xi1>>, + vector<<4xf32>> + // CHECK: arm_sve.intr.fmul{{.*}}: (vector<<4xi1>>, vector<<4xf32>>, vector<<4xf32>>) -> vector<<4xf32>> + %2 = arm_sve.masked.mulf %mask, %1, %d : vector<<4xi1>>, + vector<<4xf32>> + // CHECK: arm_sve.intr.fdiv{{.*}}: (vector<<4xi1>>, vector<<4xf32>>, vector<<4xf32>>) -> vector<<4xf32>> + %3 = arm_sve.masked.divf %mask, %2, %e : vector<<4xi1>>, + vector<<4xf32>> + return %3 : vector<<4xf32>> } -func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>) - -> !arm_sve.vector<4xi1> { - // CHECK: llvm.fcmp "oeq" {{.*}}: !llvm.vec - %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32> - return %0 : !arm_sve.vector<4xi1> +func @arm_sve_mask_genf(%a: vector<<4xf32>>, + %b: vector<<4xf32>>) + -> vector<<4xi1>> { + // CHECK: llvm.fcmp "oeq" {{.*}}: vector<<4xf32>> + %0 = arith.cmpf oeq, %a, %b : vector<<4xf32>> + return %0 : vector<<4xi1>> } -func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi1> { - // CHECK: llvm.icmp "uge" {{.*}}: !llvm.vec - %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi1> +func @arm_sve_mask_geni(%a: vector<<4xi32>>, + %b: vector<<4xi32>>) + -> vector<<4xi1>> { + // CHECK: llvm.icmp "uge" {{.*}}: vector<<4xi32>> + %0 = arith.cmpi uge, %a, %b : vector<<4xi32>> + return %0 : vector<<4xi1>> } 
-func @arm_sve_abs_diff(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi32> { - // CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32> - %z = arm_sve.subi %a, %a : !arm_sve.vector<4xi32> - // CHECK: llvm.icmp "sge" {{.*}}: !llvm.vec<? x 4 x i32> - %agb = arm_sve.cmpi sge, %a, %b : !arm_sve.vector<4xi32> - // CHECK: llvm.icmp "slt" {{.*}}: !llvm.vec<? x 4 x i32> - %bga = arm_sve.cmpi slt, %a, %b : !arm_sve.vector<4xi32> - // CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32> - %0 = arm_sve.masked.subi %agb, %a, %b : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32> - %1 = arm_sve.masked.subi %bga, %b, %a : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32> - %2 = arm_sve.masked.addi %agb, %z, %0 : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32> - %3 = arm_sve.masked.addi %bga, %2, %1 : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - return %3 : !arm_sve.vector<4xi32> +func @arm_sve_abs_diff(%a: vector<[4]xi32>, + %b: vector<[4]xi32>) + -> vector<[4]xi32> { + // CHECK: llvm.mlir.constant(dense<0> : vector<[4]xi32>) : vector<[4]xi32> + %z = arith.subi %a, %a : vector<[4]xi32> + // CHECK: llvm.icmp "sge" {{.*}}: vector<[4]xi32> + %agb = arith.cmpi sge, %a, %b : vector<[4]xi32> + // CHECK: llvm.icmp "slt" {{.*}}: vector<[4]xi32> + %bga = arith.cmpi slt, %a, %b : vector<[4]xi32> + // CHECK: "arm_sve.intr.sub"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %0 = arm_sve.masked.subi %agb, %a, %b : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: "arm_sve.intr.sub"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %1 = arm_sve.masked.subi %bga, %b, %a : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: "arm_sve.intr.add"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %2 = arm_sve.masked.addi %agb, %z, %0 : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: "arm_sve.intr.add"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %3 = arm_sve.masked.addi %bga, %2, %1 : vector<[4]xi1>, + vector<[4]xi32> + return %3 : vector<[4]xi32> } func @get_vector_scale() -> index { - // CHECK: arm_sve.vscale - %0 = arm_sve.vector_scale : index + // CHECK: llvm.intr.vscale + %0 = vector.vector_scale : index return %0 : index } diff --git a/mlir/test/Dialect/ArmSVE/memcpy.mlir b/mlir/test/Dialect/ArmSVE/memcpy.mlir deleted file mode 100644 --- a/mlir/test/Dialect/ArmSVE/memcpy.mlir +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s - -// CHECK: memcopy([[SRC:%arg[0-9]+]]: memref<?xf32>, [[DST:%arg[0-9]+]] -func @memcopy(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %vs = arm_sve.vector_scale : index - %step = arith.muli %c4, %vs : index - - // CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref<?xf32> to !llvm.struct<(ptr - // CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref<?xf32> to !llvm.struct<(ptr - // CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}} - scf.for %i0 = %c0 to %size step %step { - // CHECK: [[SRCIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64 - // CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr -
// CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32> - // CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>> - // CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]] : !llvm.ptr<vec<? x 4 x f32>> - %0 = arm_sve.load %src[%i0] : !arm_sve.vector<4xf32> from memref<?xf32> - // CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr - // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32> - // CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>> - // CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]] : !llvm.ptr<vec<? x 4 x f32>> - arm_sve.store %0, %dst[%i0] : !arm_sve.vector<4xf32> to memref<?xf32> - } - - return -} diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -1,137 +1,137 @@ // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s -func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32 +func @arm_sve_sdot(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: arm_sve.sdot {{.*}}: vector<[16]xi8> to vector<[4]xi32 %0 = arm_sve.sdot %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3 +func @arm_sve_smmla(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: arm_sve.smmla {{.*}}: vector<[16]xi8> to vector<[4]xi3 %0 = arm_sve.smmla %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_udot(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32 +func @arm_sve_udot(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: arm_sve.udot {{.*}}: vector<[16]xi8> to vector<[4]xi32 %0 = arm_sve.udot %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3 +func @arm_sve_ummla(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: arm_sve.ummla {{.*}}: vector<[16]xi8> to vector<[4]xi3 %0 = arm_sve.ummla %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>, - %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.muli {{.*}}:
!arm_sve.vector<4xi32> - %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32> - // CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32> - %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32> - return %1 : !arm_sve.vector<4xi32> +func @arm_sve_arithi(%a: vector<[4]xi32>, + %b: vector<[4]xi32>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: muli {{.*}}: vector<[4]xi32> + %0 = arith.muli %a, %b : vector<[4]xi32> + // CHECK: addi {{.*}}: vector<[4]xi32> + %1 = arith.addi %0, %c : vector<[4]xi32> + return %1 : vector<[4]xi32> } -func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>, - %c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> { - // CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32> - %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32> - // CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32> - %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32> - return %1 : !arm_sve.vector<4xf32> +func @arm_sve_arithf(%a: vector<[4]xf32>, + %b: vector<[4]xf32>, + %c: vector<[4]xf32>) -> vector<[4]xf32> { + // CHECK: mulf {{.*}}: vector<[4]xf32> + %0 = arith.mulf %a, %b : vector<[4]xf32> + // CHECK: addf {{.*}}: vector<[4]xf32> + %1 = arith.addf %0, %c : vector<[4]xf32> + return %1 : vector<[4]xf32> } -func @arm_sve_masked_arithi(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>, - %c: !arm_sve.vector<4xi32>, - %d: !arm_sve.vector<4xi32>, - %e: !arm_sve.vector<4xi32>, - %mask: !arm_sve.vector<4xi1>) - -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.masked.muli {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %0 = arm_sve.masked.muli %mask, %a, %b : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.masked.addi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %1 = arm_sve.masked.addi %mask, %0, %c : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.masked.subi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %2 = arm_sve.masked.subi %mask, %1, %d : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> +func @arm_sve_masked_arithi(%a: vector<[4]xi32>, + %b: vector<[4]xi32>, + %c: vector<[4]xi32>, + %d: vector<[4]xi32>, + %e: vector<[4]xi32>, + %mask: vector<[4]xi1>) + -> vector<[4]xi32> { + // CHECK: arm_sve.masked.muli {{.*}}: vector<[4]xi1>, vector< + %0 = arm_sve.masked.muli %mask, %a, %b : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: arm_sve.masked.addi {{.*}}: vector<[4]xi1>, vector< + %1 = arm_sve.masked.addi %mask, %0, %c : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: arm_sve.masked.subi {{.*}}: vector<[4]xi1>, vector< + %2 = arm_sve.masked.subi %mask, %1, %d : vector<[4]xi1>, + vector<[4]xi32> // CHECK: arm_sve.masked.divi_signed - %3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> + %3 = arm_sve.masked.divi_signed %mask, %2, %e : vector<[4]xi1>, + vector<[4]xi32> // CHECK: arm_sve.masked.divi_unsigned - %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - return %2 : !arm_sve.vector<4xi32> + %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : vector<[4]xi1>, + vector<[4]xi32> + return %2 : vector<[4]xi32> } -func @arm_sve_masked_arithf(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>, - %c: !arm_sve.vector<4xf32>, - %d: !arm_sve.vector<4xf32>, - %e: !arm_sve.vector<4xf32>, - %mask: !arm_sve.vector<4xi1>) - -> !arm_sve.vector<4xf32> { - // CHECK: arm_sve.masked.mulf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %0 = arm_sve.masked.mulf %mask, %a, %b : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.masked.addf {{.*}}:
!arm_sve.vector<4xi1>, !arm_sve.vector - %1 = arm_sve.masked.addf %mask, %0, %c : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.masked.subf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %2 = arm_sve.masked.subf %mask, %1, %d : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.masked.divf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - return %3 : !arm_sve.vector<4xf32> +func @arm_sve_masked_arithf(%a: vector<[4]xf32>, + %b: vector<[4]xf32>, + %c: vector<[4]xf32>, + %d: vector<[4]xf32>, + %e: vector<[4]xf32>, + %mask: vector<[4]xi1>) + -> vector<[4]xf32> { + // CHECK: arm_sve.masked.mulf {{.*}}: vector<[4]xi1>, vector< + %0 = arm_sve.masked.mulf %mask, %a, %b : vector<[4]xi1>, + vector<[4]xf32> + // CHECK: arm_sve.masked.addf {{.*}}: vector<[4]xi1>, vector< + %1 = arm_sve.masked.addf %mask, %0, %c : vector<[4]xi1>, + vector<[4]xf32> + // CHECK: arm_sve.masked.subf {{.*}}: vector<[4]xi1>, vector< + %2 = arm_sve.masked.subf %mask, %1, %d : vector<[4]xi1>, + vector<[4]xf32> + // CHECK: arm_sve.masked.divf {{.*}}: vector<[4]xi1>, vector< + %3 = arm_sve.masked.divf %mask, %2, %e : vector<[4]xi1>, + vector<[4]xf32> + return %3 : vector<[4]xf32> } -func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>) - -> !arm_sve.vector<4xi1> { - // CHECK: arm_sve.cmpf oeq, {{.*}}: !arm_sve.vector<4xf32> - %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32> - return %0 : !arm_sve.vector<4xi1> +func @arm_sve_mask_genf(%a: vector<[4]xf32>, + %b: vector<[4]xf32>) + -> vector<[4]xi1> { + // CHECK: cmpf oeq, {{.*}}: vector<[4]xf32> + %0 = arith.cmpf oeq, %a, %b : vector<[4]xf32> + return %0 : vector<[4]xi1> } -func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi1> { - // CHECK: arm_sve.cmpi uge, {{.*}}: !arm_sve.vector<4xi32> - %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi1> +func @arm_sve_mask_geni(%a: vector<[4]xi32>, + %b: vector<[4]xi32>) + -> vector<[4]xi1> { + // CHECK: cmpi uge, {{.*}}: vector<[4]xi32> + %0 = arith.cmpi uge, %a, %b : vector<[4]xi32> + return %0 : vector<[4]xi1> } -func @arm_sve_memory(%v: !arm_sve.vector<4xi32>, +func @arm_sve_memory(%v: vector<[4]xi32>, %m: memref<?xi32>) - -> !arm_sve.vector<4xi32> { + -> vector<[4]xi32> { %c0 = arith.constant 0 : index - // CHECK: arm_sve.load {{.*}}: !arm_sve.vector<4xi32> from memref<?xi32> - %0 = arm_sve.load %m[%c0] : !arm_sve.vector<4xi32> from memref<?xi32> - // CHECK: arm_sve.store {{.*}}: !arm_sve.vector<4xi32> to memref<?xi32> - arm_sve.store %v, %m[%c0] : !arm_sve.vector<4xi32> to memref<?xi32> - return %0 : !arm_sve.vector<4xi32> + // CHECK: vector.load {{.*}}: memref<?xi32>, vector<[4]xi32> + %0 = vector.load %m[%c0] : memref<?xi32>, vector<[4]xi32> + // CHECK: vector.store {{.*}}: memref<?xi32>, vector<[4]xi32> + vector.store %v, %m[%c0] : memref<?xi32>, vector<[4]xi32> + return %0 : vector<[4]xi32> } func @get_vector_scale() -> index { - // CHECK: arm_sve.vector_scale : index - %0 = arm_sve.vector_scale : index + // CHECK: vector.vector_scale : index + %0 = vector.vector_scale : index return %0 : index } diff --git a/mlir/test/Dialect/ArmSVE/scalable-memcpy.mlir b/mlir/test/Dialect/ArmSVE/scalable-memcpy.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ArmSVE/scalable-memcpy.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm | mlir-opt | FileCheck %s + +// CHECK: scalable_memcopy([[SRC:%arg[0-9]+]]: memref<?xf32>,
[[DST:%arg[0-9]+]] +func @scalable_memcopy(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %vs = vector.vector_scale : index + %step = arith.muli %c4, %vs : index + // CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref<?xf32> to !llvm.struct<(ptr + // CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref<?xf32> to !llvm.struct<(ptr + // CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}} + scf.for %i0 = %c0 to %size step %step { + // CHECK: [[DATAIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64 + // CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr + // CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[DATAIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32> + // CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>> + // CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]]{{.*}}: !llvm.ptr<vector<[4]xf32>> + %0 = vector.load %src[%i0] : memref<?xf32>, vector<[4]xf32> + // CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr + // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[DATAIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32> + // CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>> + // CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]]{{.*}}: !llvm.ptr<vector<[4]xf32>> + vector.store %0, %dst[%i0] : memref<?xf32>, vector<[4]xf32> + } + + return +} diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -1,193 +1,193 @@ // RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_sdot -llvm.func @arm_sve_sdot(%arg0: !llvm.vec<? x 16 x i8>, - %arg1: !llvm.vec<? x 16 x i8>, - %arg2: !llvm.vec<? x 4 x i32>) - -> !llvm.vec<? x 4 x i32> { +llvm.func @arm_sve_sdot(%arg0: vector<[16]xi8>, + %arg1: vector<[16]xi8>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32 %0 = "arm_sve.intr.sdot"(%arg2, %arg0, %arg1) : - (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) - -> !llvm.vec<? x 4 x i32> - llvm.return %0 : !llvm.vec<? x 4 x i32> + (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) + -> vector<[4]xi32> + llvm.return %0 : vector<[4]xi32> } // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_smmla -llvm.func @arm_sve_smmla(%arg0: !llvm.vec<? x 16 x i8>, - %arg1: !llvm.vec<? x 16 x i8>, - %arg2: !llvm.vec<? x 4 x i32>) - -> !llvm.vec<? x 4 x i32> { +llvm.func @arm_sve_smmla(%arg0: vector<[16]xi8>, + %arg1: vector<[16]xi8>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4 x i32 %0 = "arm_sve.intr.smmla"(%arg2, %arg0, %arg1) : - (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) - -> !llvm.vec<? x 4 x i32> - llvm.return %0 : !llvm.vec<? x 4 x i32> + (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) + -> vector<[4]xi32> + llvm.return %0 : vector<[4]xi32> } // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_udot -llvm.func @arm_sve_udot(%arg0: !llvm.vec<? x 16 x i8>, - %arg1: !llvm.vec<? x 16 x i8>, - %arg2: !llvm.vec<? x 4 x i32>) - -> !llvm.vec<? x 4 x i32> { +llvm.func @arm_sve_udot(%arg0: vector<[16]xi8>, + %arg1: vector<[16]xi8>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32 %0 = "arm_sve.intr.udot"(%arg2, %arg0, %arg1) : - (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) - -> !llvm.vec<? x 4 x i32> - llvm.return %0 : !llvm.vec<? x 4 x i32> + (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) + -> vector<[4]xi32> + llvm.return %0 : vector<[4]xi32> } // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_ummla -llvm.func @arm_sve_ummla(%arg0: !llvm.vec<? x 16 x i8>, - %arg1: !llvm.vec<? x 16 x i8>, - %arg2: !llvm.vec<? x 4 x i32>) - -> !llvm.vec<? x 4 x i32> { +llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>, + %arg1: vector<[16]xi8>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4 x i32 %0 = "arm_sve.intr.ummla"(%arg2, %arg0, %arg1) : - (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) -
-> !llvm.vec<? x 4 x i32> - llvm.return %0 : !llvm.vec<? x 4 x i32> + (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) + -> vector<[4]xi32> + llvm.return %0 : vector<[4]xi32> } // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi -llvm.func @arm_sve_arithi(%arg0: !llvm.vec<? x 4 x i32>, - %arg1: !llvm.vec<? x 4 x i32>, - %arg2: !llvm.vec<? x 4 x i32>) - -> !llvm.vec<? x 4 x i32> { +llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>, + %arg1: vector<[4]xi32>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: mul - %0 = llvm.mul %arg0, %arg1 : !llvm.vec<? x 4 x i32> + %0 = llvm.mul %arg0, %arg1 : vector<[4]xi32> // CHECK: add - %1 = llvm.add %0, %arg2 : !llvm.vec<? x 4 x i32> - llvm.return %1 : !llvm.vec<? x 4 x i32> + %1 = llvm.add %0, %arg2 : vector<[4]xi32> + llvm.return %1 : vector<[4]xi32> } // CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf -llvm.func @arm_sve_arithf(%arg0: !llvm.vec<? x 4 x f32>, - %arg1: !llvm.vec<? x 4 x f32>, - %arg2: !llvm.vec<? x 4 x f32>) - -> !llvm.vec<? x 4 x f32> { +llvm.func @arm_sve_arithf(%arg0: vector<[4]xf32>, + %arg1: vector<[4]xf32>, + %arg2: vector<[4]xf32>) + -> vector<[4]xf32> { // CHECK: fmul - %0 = llvm.fmul %arg0, %arg1 : !llvm.vec<? x 4 x f32> + %0 = llvm.fmul %arg0, %arg1 : vector<[4]xf32> // CHECK: fadd - %1 = llvm.fadd %0, %arg2 : !llvm.vec<? x 4 x f32> - llvm.return %1 : !llvm.vec<? x 4 x f32> + %1 = llvm.fadd %0, %arg2 : vector<[4]xf32> + llvm.return %1 : vector<[4]xf32> } // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi_masked -llvm.func @arm_sve_arithi_masked(%arg0: !llvm.vec<? x 4 x i32>, - %arg1: !llvm.vec<? x 4 x i32>, - %arg2: !llvm.vec<? x 4 x i32>, - %arg3: !llvm.vec<? x 4 x i32>, - %arg4: !llvm.vec<? x 4 x i32>, - %arg5: !llvm.vec<? x 4 x i1>) - -> !llvm.vec<? x 4 x i32> { +llvm.func @arm_sve_arithi_masked(%arg0: vector<[4]xi32>, + %arg1: vector<[4]xi32>, + %arg2: vector<[4]xi32>, + %arg3: vector<[4]xi32>, + %arg4: vector<[4]xi32>, + %arg5: vector<[4]xi1>) + -> vector<[4]xi32> { // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32 - %0 = "arm_sve.intr.add"(%arg5, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>, - !llvm.vec<? x 4 x i32>, - !llvm.vec<? x 4 x i32>) - -> !llvm.vec<? x 4 x i32> + %0 = "arm_sve.intr.add"(%arg5, %arg0, %arg1) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32 - %1 = "arm_sve.intr.sub"(%arg5, %0, %arg1) : (!llvm.vec<? x 4 x i1>, - !llvm.vec<? x 4 x i32>, - !llvm.vec<? x 4 x i32>) - -> !llvm.vec<? x 4 x i32> + %1 = "arm_sve.intr.sub"(%arg5, %0, %arg1) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32 - %2 = "arm_sve.intr.mul"(%arg5, %1, %arg3) : (!llvm.vec<? x 4 x i1>, - !llvm.vec<? x 4 x i32>, - !llvm.vec<? x 4 x i32>) - -> !llvm.vec<? x 4 x i32> + %2 = "arm_sve.intr.mul"(%arg5, %1, %arg3) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdiv.nxv4i32 - %3 = "arm_sve.intr.sdiv"(%arg5, %2, %arg4) : (!llvm.vec<? x 4 x i1>, - !llvm.vec<? x 4 x i32>, - !llvm.vec<? x 4 x i32>) - -> !llvm.vec<? x 4 x i32> + %3 = "arm_sve.intr.sdiv"(%arg5, %2, %arg4) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udiv.nxv4i32 - %4 = "arm_sve.intr.udiv"(%arg5, %3, %arg4) : (!llvm.vec<? x 4 x i1>, - !llvm.vec<? x 4 x i32>, - !llvm.vec<? x 4 x i32>) - -> !llvm.vec<? x 4 x i32> - llvm.return %4 : !llvm.vec<? x 4 x i32> + %4 = "arm_sve.intr.udiv"(%arg5, %3, %arg4) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> + llvm.return %4 : vector<[4]xi32> } // CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf_masked -llvm.func @arm_sve_arithf_masked(%arg0: !llvm.vec<? x 4 x f32>, - %arg1: !llvm.vec<? x 4 x f32>, - %arg2: !llvm.vec<? x 4 x f32>, - %arg3: !llvm.vec<? x 4 x f32>, - %arg4: !llvm.vec<? x 4 x f32>, - %arg5: !llvm.vec<? x 4 x i1>) - -> !llvm.vec<? x 4 x f32> { +llvm.func @arm_sve_arithf_masked(%arg0: vector<[4]xf32>, + %arg1: vector<[4]xf32>, + %arg2: vector<[4]xf32>, + %arg3: vector<[4]xf32>, + %arg4: vector<[4]xf32>, + %arg5: vector<[4]xi1>) + -> vector<[4]xf32> { // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32 - %0 =
"arm_sve.intr.fadd"(%arg5, %arg0, %arg1) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %0 = "arm_sve.intr.fadd"(%arg5, %arg0, %arg1) : (vector<<4xi1>>, + vector<<4xf32>>, + vector<<4xf32>>) + -> vector<<4xf32>> // CHECK: call @llvm.aarch64.sve.fsub.nxv4f32 - %1 = "arm_sve.intr.fsub"(%arg5, %0, %arg2) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %1 = "arm_sve.intr.fsub"(%arg5, %0, %arg2) : (vector<<4xi1>>, + vector<<4xf32>>, + vector<<4xf32>>) + -> vector<<4xf32>> // CHECK: call @llvm.aarch64.sve.fmul.nxv4f32 - %2 = "arm_sve.intr.fmul"(%arg5, %1, %arg3) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %2 = "arm_sve.intr.fmul"(%arg5, %1, %arg3) : (vector<<4xi1>>, + vector<<4xf32>>, + vector<<4xf32>>) + -> vector<<4xf32>> // CHECK: call @llvm.aarch64.sve.fdiv.nxv4f32 - %3 = "arm_sve.intr.fdiv"(%arg5, %2, %arg4) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec - llvm.return %3 : !llvm.vec + %3 = "arm_sve.intr.fdiv"(%arg5, %2, %arg4) : (vector<<4xi1>>, + vector<<4xf32>>, + vector<<4xf32>>) + -> vector<<4xf32>> + llvm.return %3 : vector<<4xf32>> } // CHECK-LABEL: define @arm_sve_mask_genf -llvm.func @arm_sve_mask_genf(%arg0: !llvm.vec, - %arg1: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_mask_genf(%arg0: vector<<4xf32>>, + %arg1: vector<<4xf32>>) + -> vector<<4xi1>> { // CHECK: fcmp oeq - %0 = llvm.fcmp "oeq" %arg0, %arg1 : !llvm.vec - llvm.return %0 : !llvm.vec + %0 = llvm.fcmp "oeq" %arg0, %arg1 : vector<<4xf32>> + llvm.return %0 : vector<<4xi1>> } // CHECK-LABEL: define @arm_sve_mask_geni -llvm.func @arm_sve_mask_geni(%arg0: !llvm.vec, - %arg1: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_mask_geni(%arg0: vector<<4xi32>>, + %arg1: vector<<4xi32>>) + -> vector<<4xi1>> { // CHECK: icmp uge - %0 = llvm.icmp "uge" %arg0, %arg1 : !llvm.vec - llvm.return %0 : !llvm.vec + %0 = llvm.icmp "uge" %arg0, %arg1 : vector<<4xi32>> + llvm.return %0 : vector<<4xi1>> } // CHECK-LABEL: define @arm_sve_abs_diff -llvm.func @arm_sve_abs_diff(%arg0: !llvm.vec, - %arg1: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_abs_diff(%arg0: vector<<4xi32>>, + %arg1: vector<<4xi32>>) + -> vector<<4xi32>> { // CHECK: sub - %0 = llvm.sub %arg0, %arg0 : !llvm.vec + %0 = llvm.sub %arg0, %arg0 : vector<<4xi32>> // CHECK: icmp sge - %1 = llvm.icmp "sge" %arg0, %arg1 : !llvm.vec + %1 = llvm.icmp "sge" %arg0, %arg1 : vector<<4xi32>> // CHECK: icmp slt - %2 = llvm.icmp "slt" %arg0, %arg1 : !llvm.vec + %2 = llvm.icmp "slt" %arg0, %arg1 : vector<<4xi32>> // CHECK: call @llvm.aarch64.sve.sub.nxv4i32 - %3 = "arm_sve.intr.sub"(%1, %arg0, %arg1) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %3 = "arm_sve.intr.sub"(%1, %arg0, %arg1) : (vector<<4xi1>>, + vector<<4xi32>>, + vector<<4xi32>>) + -> vector<<4xi32>> // CHECK: call @llvm.aarch64.sve.sub.nxv4i32 - %4 = "arm_sve.intr.sub"(%2, %arg1, %arg0) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %4 = "arm_sve.intr.sub"(%2, %arg1, %arg0) : (vector<<4xi1>>, + vector<<4xi32>>, + vector<<4xi32>>) + -> vector<<4xi32>> // CHECK: call @llvm.aarch64.sve.add.nxv4i32 - %5 = "arm_sve.intr.add"(%1, %0, %3) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %5 = "arm_sve.intr.add"(%1, %0, %3) : (vector<<4xi1>>, + vector<<4xi32>>, + vector<<4xi32>>) + -> vector<<4xi32>> // CHECK: call @llvm.aarch64.sve.add.nxv4i32 - %6 = "arm_sve.intr.add"(%2, %5, %4) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec - llvm.return %6 : !llvm.vec + %6 = "arm_sve.intr.add"(%2, %5, %4) : (vector<<4xi1>>, + vector<<4xi32>>, + vector<<4xi32>>) + -> 
vector<[4]xi32> + llvm.return %6 : vector<[4]xi32> } // CHECK-LABEL: define void @memcopy @@ -234,7 +234,7 @@ %12 = llvm.mlir.constant(0 : index) : i64 %13 = llvm.mlir.constant(4 : index) : i64 // CHECK: [[VL:%[0-9]+]] = call i64 @llvm.vscale.i64() - %14 = "arm_sve.vscale"() : () -> i64 + %14 = "llvm.intr.vscale"() : () -> i64 // CHECK: mul i64 [[VL]], 4 %15 = llvm.mul %14, %13 : i64 llvm.br ^bb1(%12 : i64) @@ -249,9 +249,9 @@ // CHECK: etelementptr float, float* %19 = llvm.getelementptr %18[%16] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32> // CHECK: bitcast float* %{{[0-9]+}} to <vscale x 4 x float>* - %20 = llvm.bitcast %19 : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>> + %20 = llvm.bitcast %19 : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>> // CHECK: load <vscale x 4 x float>, <vscale x 4 x float>* - %21 = llvm.load %20 : !llvm.ptr<vec<? x 4 x f32>> + %21 = llvm.load %20 : !llvm.ptr<vector<[4]xf32>> // CHECK: extractvalue { float*, float*, i64, [1 x i64], [1 x i64] } %22 = llvm.extractvalue %11[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, @@ -259,9 +259,9 @@ // CHECK: getelementptr float, float* %32 %23 = llvm.getelementptr %22[%16] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32> // CHECK: bitcast float* %33 to <vscale x 4 x float>* - %24 = llvm.bitcast %23 : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>> + %24 = llvm.bitcast %23 : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>> // CHECK: store <vscale x 4 x float> %{{[0-9]+}}, <vscale x 4 x float>* %{{[0-9]+}} - llvm.store %21, %24 : !llvm.ptr<vec<? x 4 x f32>> + llvm.store %21, %24 : !llvm.ptr<vector<[4]xf32>> %25 = llvm.add %16, %15 : i64 llvm.br ^bb1(%25 : i64) ^bb3: @@ -271,6 +271,6 @@ // CHECK-LABEL: define i64 @get_vector_scale() llvm.func @get_vector_scale() -> i64 { // CHECK: call i64 @llvm.vscale.i64() - %0 = "arm_sve.vscale"() : () -> i64 + %0 = "llvm.intr.vscale"() : () -> i64 llvm.return %0 : i64 }