diff --git a/llvm/include/llvm/IR/IntrinsicsARM.td b/llvm/include/llvm/IR/IntrinsicsARM.td --- a/llvm/include/llvm/IR/IntrinsicsARM.td +++ b/llvm/include/llvm/IR/IntrinsicsARM.td @@ -922,7 +922,7 @@ list props = [IntrNoMem]> { def "": Intrinsic; def _predicated: Intrinsic(rets[0]), "llvm_anyvector_ty"), + !if(!eq(rets[0], llvm_anyvector_ty), LLVMMatchType<0>, rets[0])], props>; } diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -231,11 +231,11 @@ def NVVM_MMA_OPS : NVVM_MMA_OPS; -// Returns [1] if this combination of layout/satf is supported, [] otherwise. +// Returns true if this combination of layout/satf is supported; false otherwise. // MMA ops must provide all parameters. Loads and stores -- only frags and layout_a. // The class is used to prevent generation of records for the unsupported variants. // E.g. -// foreach _ = NVVM_MMA_SUPPORTED<...>.ret in = +// if NVVM_MMA_SUPPORTED<...>.ret then // def : FOO<>; // The record will only be defined for supported ops. // class NVVM_MMA_SUPPORTED frags, string layout_a, string layout_b="-", int satf=-1> { @@ -261,20 +261,20 @@ # !if(!eq(!size(frags), 4), frags[2].ptx_elt_type # frags[3].ptx_elt_type, "?"); - list ret = !cond( + bit ret = !cond( // Sub-int MMA only supports fixed A/B layout. // b1 does not support .satf. - !eq(mma#":"#satf, "b1:row:col:0") : [1], + !eq(mma#":"#satf, "b1:row:col:0") : true, // mma.m8n8k4 has no .satf modifier. !and(!eq(frags[0].geom, "m8n8k4"), - !ne(satf, 0)): [], + !ne(satf, 0)): false, // mma.m8n8k4 has no C=f32 D=f16 variant. - !eq(gcd, "m8n8k4:f32f16"): [], - !eq(mma, "s4:row:col") : [1], - !eq(mma, "u4:row:col") : [1], - !eq(mma, "s4:row:col") : [1], - !eq(mma, "u4:row:col") : [1], + !eq(gcd, "m8n8k4:f32f16"): false, + !eq(mma, "s4:row:col") : true, + !eq(mma, "u4:row:col") : true, + !eq(mma, "s4:row:col") : true, + !eq(mma, "u4:row:col") : true, // Sub-int load/stores have fixed layout for A and B. !and(!eq(layout_b, "-"), // It's a Load or Store op !or(!eq(ld, "b1:a:row"), @@ -288,13 +288,13 @@ !eq(ld, "u4:a:row"), !eq(ld, "u4:b:col"), !eq(ldf, "u4:c"), - !eq(ldf, "u4:d"))) : [1], + !eq(ldf, "u4:d"))) : true, // All other sub-int ops are not supported. - !eq(t, "b1") : [], - !eq(t, "s4") : [], - !eq(t, "u4") : [], + !eq(t, "b1") : false, + !eq(t, "s4") : false, + !eq(t, "u4") : false, // All other (non sub-int) are OK. - true: [1] + true: true ); } @@ -4120,11 +4120,11 @@ foreach layout = ["row", "col"] in { foreach stride = [0, 1] in { foreach frag = NVVM_MMA_OPS.all_ld_ops in - foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in + if NVVM_MMA_SUPPORTED<[frag], layout>.ret then def WMMA_NAME_LDST<"load", frag, layout, stride>.record : NVVM_WMMA_LD; foreach frag = NVVM_MMA_OPS.all_st_ops in - foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in + if NVVM_MMA_SUPPORTED<[frag], layout>.ret then def WMMA_NAME_LDST<"store", frag, layout, stride>.record : NVVM_WMMA_ST; } @@ -4143,7 +4143,7 @@ foreach layout_b = ["row", "col"] in { foreach satf = [0, 1] in { foreach op = NVVM_MMA_OPS.all_mma_ops in { - foreach _ = NVVM_MMA_SUPPORTED.ret in { + if NVVM_MMA_SUPPORTED.ret then { def WMMA_NAME_MMA.record : NVVM_WMMA_MMA options, code value = "true"> { code ImpliedCheck = !foldl("false", options, accumulator, option, - !strconcat(accumulator, " || ", !cast(option.KeyPath))); + !strconcat(accumulator, " || ", option.KeyPath)); code ImpliedValue = value; } diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -7523,10 +7523,10 @@ foreach space = [".global", ".shared", ""] in { foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { foreach frag = NVVM_MMA_OPS.all_ld_ops in - foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in + if NVVM_MMA_SUPPORTED<[frag], layout>.ret then def : WMMA_LOAD, layout, space, stride, addr>; foreach frag = NVVM_MMA_OPS.all_st_ops in - foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in + if NVVM_MMA_SUPPORTED<[frag], layout>.ret then def : WMMA_STORE_D, layout, space, stride, addr>; } // addr } // space @@ -7584,7 +7584,7 @@ foreach layout_b = ["row", "col"] in { foreach satf = [0, 1] in { foreach op = NVVM_MMA_OPS.all_mma_ops in { - foreach _ = NVVM_MMA_SUPPORTED.ret in { + if NVVM_MMA_SUPPORTED.ret then { def : WMMA_MMA, WMMA_REGINFO, WMMA_REGINFO, diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -113,7 +113,7 @@ class ScalableVectorOfLength allowedLengths> : Type< IsScalableVectorOfLengthPred, - " of length " # StrJoinInt.result>; + " of length " # !interleave(allowedLengths, "/")>; class ScalableVectorOfLengthAndType allowedLengths, list allowedTypes> : Type< diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -258,10 +258,10 @@ llvm::Function *fn = llvm::Intrinsic::getDeclaration( module, llvm::Intrinsic::}] # enumName # [{, - { }] # StrJoin.lst, ListIntSubst.lst)>.result # [{ + overloadedOperands>.lst), ", ") # [{ }); auto operands = lookupValues(opInst.getOperands()); }] # !if(!gt(numResults, 0), "$res = ", "") @@ -326,7 +326,8 @@ llvm::Function *fn = llvm::Intrinsic::getDeclaration( module, llvm::Intrinsic::vector_reduce_}] # mnem # [{, - { }] # StrJoin.lst>.result # [{ + { }] # !interleave(ListIntSubst.lst, + ", ") # [{ }); auto operands = lookupValues(opInst.getOperands()); llvm::FastMathFlags origFM = builder.getFastMathFlags(); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -176,8 +176,8 @@ // Pack all extensions as a static array and get its reference. let instancePreparation = !if(!empty(extensions), "", "static const ::mlir::spirv::Extension exts[] = {" # - StrJoin.result # + !interleave(!foreach(ext, extensions, + "::mlir::spirv::Extension::" # ext.symbol), ", ") # "}; " # // The following manual ArrayRef constructor call is to satisfy GCC 5. "ArrayRef<::mlir::spirv::Extension> " # @@ -217,8 +217,8 @@ // Pack all capabilities as a static array and get its reference. let instancePreparation = !if(!empty(capabilities), "", "static const ::mlir::spirv::Capability caps[] = {" # - StrJoin.result # + !interleave(!foreach(cap, capabilities, + "::mlir::spirv::Capability::" # cap.symbol), ", ") # "}; " # // The following manual ArrayRef constructor call is to satisfy GCC 5. "ArrayRef<::mlir::spirv::Capability> " # @@ -3025,7 +3025,7 @@ class SignlessOrUnsignedIntOfWidths widths> : AnyTypeOf), - StrJoinInt.result # "-bit signless/unsigned integer">; + !interleave(widths, "/") # "-bit signless/unsigned integer">; def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; def SPV_IsCooperativeMatrixType : diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -29,8 +29,8 @@ ".getStorageTypeIntegralWidth() == " # !head(params)>]>, "Q" # !if (signed, "int", "uint") # !head(params) # " type"> { string name = n; - string asTraitArgsStr = - StrJoinInt.result # !if(signed, ", true", ", false"); + string asTraitArgsStr = !interleave(params, ", ") # + !if(signed, ", true", ", false"); } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -38,16 +38,6 @@ string result = r; } -// TODO: Use !interleave() directly rather than through StrJoin/StrJoinInt. - -// Concatenates a list of strings with a separator (default ", ") -class StrJoin strings, string sep = ", "> : - StrFunc; - -// Concatenates a list of integers into a string with a separator (default ", ") -class StrJoinInt integers, string sep = ", "> : - StrFunc; - //===----------------------------------------------------------------------===// // Predicate definitions //===----------------------------------------------------------------------===// @@ -354,7 +344,7 @@ // Satisfy any of the allowed type's condition Or, !if(!eq(description, ""), - StrJoin.result, + !interleave(!foreach(t, allowedTypes, t.description), " or "), description)>; // Integer types. @@ -371,7 +361,7 @@ class AnyIntOfWidths widths> : AnyTypeOf), - StrJoinInt.result # "-bit integer", + !interleave(widths, "/") # "-bit integer", "::mlir::IntegerType">; def AnyI1 : AnyI<1>; @@ -395,7 +385,7 @@ class SignlessIntOfWidths widths> : AnyTypeOf), - StrJoinInt.result # "-bit signless integer">; + !interleave(widths, "/") # "-bit signless integer">; def I1 : I<1>; def I8 : I<8>; @@ -418,7 +408,7 @@ class SignedIntOfWidths widths> : AnyTypeOf), - StrJoinInt.result # "-bit signed integer">; + !interleave(widths, "/") # "-bit signed integer">; def SI1 : SI<1>; def SI8 : SI<8>; @@ -441,7 +431,7 @@ class UnsignedIntOfWidths widths> : AnyTypeOf), - StrJoinInt.result # "-bit unsigned integer">; + !interleave(widths, "/") # "-bit unsigned integer">; def UI1 : UI<1>; def UI8 : UI<8>; @@ -470,7 +460,7 @@ class FloatOfWidths widths> : AnyTypeOf), - StrJoinInt.result # "-bit float">; + !interleave(widths, "/") # "-bit float">; def F16 : F<16>; def F32 : F<32>; @@ -560,7 +550,7 @@ // Any vector where the rank is from the given `allowedRanks` list class VectorOfRank allowedRanks> : Type< IsVectorOfRankPred, - " of ranks " # StrJoinInt.result, "::mlir::VectorType">; + " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">; // Any vector where the rank is from the given `allowedRanks` list and the type // is from the given `allowedTypes` list @@ -585,10 +575,9 @@ // `allowedLengths` list class VectorOfLength allowedLengths> : Type< IsVectorOfLengthPred, - " of length " # StrJoinInt.result, + " of length " # !interleave(allowedLengths, "/"), "::mlir::VectorType">; - // Any vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` // list @@ -643,7 +632,7 @@ // Ranked tensor type with one of the specified types and ranks. class TensorRankOf allowedTypes, list ranks> : Type.predicate, HasAnyRankOfPred]>, - StrJoin.result # " " # + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # TensorOf.description, "::mlir::TensorType">; class 0DTensorOf allowedTypes> : TensorRankOf; @@ -684,7 +673,7 @@ // TODO: Have an easy way to add another constraint to a type. class MemRefRankOf allowedTypes, list ranks> : Type.predicate, HasAnyRankOfPred]>, - StrJoin.result # " " # + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # MemRefOf.description>; class StaticShapeMemRefOf allowedTypes> @@ -709,7 +698,7 @@ class StridedMemRefRankOf allowedTypes, list ranks> : Type.predicate, HasAnyRankOfPred]>, - StrJoin.result # " " # + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # MemRefOf.description>; // This represents a generic tuple without any constraints on element type. @@ -1205,7 +1194,7 @@ StringBasedAttr< And<[StrAttr.predicate, Or]>, !if(!empty(description), "allowed string cases: " # - StrJoin.result, + !interleave(!foreach(case, cases, "'" # case.symbol # "'"), ", "), description)>; // An enum attribute backed by IntegerAttr. @@ -1218,7 +1207,7 @@ EnumAttrInfo, SignlessIntegerAttrBase.result, description)> { + !interleave(!foreach(case, cases, case.value), ", "), description)> { let predicate = And<[ SignlessIntegerAttrBase.predicate, Or]>; @@ -1256,7 +1245,7 @@ I32Attr.predicate, // Make sure we don't have unknown bit set. CPred<"!($_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & (~(" - # StrJoin.result # + # !interleave(!foreach(case, cases, case.value # "u"), "|") # ")))"> ]>; @@ -1347,13 +1336,13 @@ let predicate = And<[ SignlessIntElementsAttr.predicate, CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType().getShape() == " - "::mlir::ArrayRef({" # StrJoinInt.result # "})">]>; + "::mlir::ArrayRef({" # !interleave(dims, ", ") # "})">]>; let description = width # "-bit signless int elements attribute of shape [" # - StrJoinInt.result # "]"; + !interleave(dims, ", ") # "]"; let constBuilderCall = "::mlir::DenseIntElementsAttr::get(" - "::mlir::RankedTensorType::get({" # StrJoinInt.result # + "::mlir::RankedTensorType::get({" # !interleave(dims, ", ") # "}, $_builder.getIntegerType(" # width # ")), ::llvm::makeArrayRef($0))"; } @@ -1389,15 +1378,15 @@ // Check that this is ranked and has the specified shape. "$_self.cast<::mlir::DenseFPElementsAttr>().getType().hasRank() && " "$_self.cast<::mlir::DenseFPElementsAttr>().getType().getShape() == " - "::mlir::ArrayRef({" # StrJoinInt.result # "})">, + "::mlir::ArrayRef({" # !interleave(dims, ", ") # "})">, width # "-bit float elements attribute of shape [" # - StrJoinInt.result # "]"> { + !interleave(dims, ", ") # "]"> { let storageType = [{ ::mlir::DenseFPElementsAttr }]; let returnType = [{ ::mlir::DenseFPElementsAttr }]; let constBuilderCall = "::mlir::DenseElementsAttr::get(" - "::mlir::RankedTensorType::get({" # StrJoinInt.result # + "::mlir::RankedTensorType::get({" # !interleave(dims, ", ") # "}, $_builder.getF" # width # "Type()), " "::llvm::makeArrayRef($0)).cast<::mlir::DenseFPElementsAttr>()"; let convertFromStorage = "$_self"; @@ -1501,7 +1490,7 @@ DictionaryAttrBase()">, "DictionaryAttr with field(s): " # - StrJoin.result # + !interleave(!foreach(a, attributes, "'" # a.name # "'"), ", ") # " (each field having its own constraints)"> { // Name for this StructAttr. string className = name; @@ -1803,7 +1792,7 @@ : ParamNativeOpTrait<"HasParent", op>; class ParentOneOf ops> - : ParamNativeOpTrait<"HasParent", StrJoin.result>; + : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>; // Op result type is derived from the first attribute. If the attribute is an // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the @@ -2170,7 +2159,7 @@ class AllMatchPred values> : CPred<"::llvm::is_splat(::llvm::makeArrayRef({" - # StrJoin.result #"}))">; + # !interleave(values, ", ") #"}))">; class AllMatch values, string description> : PredOpTrait>; @@ -2182,7 +2171,7 @@ class AllMatchSameOperatorTrait names, string operator, string description> : PredOpTrait< - "all of {" # StrJoin.result # "} have same " # description, + "all of {" # !interleave(names, ", ") # "} have same " # description, AllMatchSameOperatorPred> { list values = names; } @@ -2299,7 +2288,7 @@ // 2) the indices are not out of range. class TCopVTEtAreSameAt indices> : CPred< "::llvm::is_splat(::llvm::map_range(" - "::mlir::ArrayRef({" # StrJoinInt.result # "}), " + "::mlir::ArrayRef({" # !interleave(indices, ", ") # "}), " "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); " "}))">;