@@ -176,11 +176,95 @@ AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
176
176
return TTI::PSK_Software;
177
177
}
178
178
179
+ bool AArch64TTIImpl::isWideningInstruction (Type *DstTy, unsigned Opcode,
180
+ ArrayRef<const Value *> Args) {
181
+
182
+ // A helper that returns a vector type from the given type. The number of
183
+ // elements in type Ty determine the vector width.
184
+ auto toVectorTy = [&](Type *ArgTy) {
185
+ return VectorType::get (ArgTy->getScalarType (),
186
+ DstTy->getVectorNumElements ());
187
+ };
188
+
189
+ // Exit early if DstTy is not a vector type whose elements are at least
190
+ // 16-bits wide.
191
+ if (!DstTy->isVectorTy () || DstTy->getScalarSizeInBits () < 16 )
192
+ return false ;
193
+
194
+ // Determine if the operation has a widening variant. We consider both the
195
+ // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
196
+ // instructions.
197
+ //
198
+ // TODO: Add additional widening operations (e.g., mul, shl, etc.) once we
199
+ // verify that their extending operands are eliminated during code
200
+ // generation.
201
+ switch (Opcode) {
202
+ case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
203
+ case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
204
+ break ;
205
+ default :
206
+ return false ;
207
+ }
208
+
209
+ // To be a widening instruction (either the "wide" or "long" versions), the
210
+ // second operand must be a sign- or zero extend having a single user. We
211
+ // only consider extends having a single user because they may otherwise not
212
+ // be eliminated.
213
+ if (Args.size () != 2 ||
214
+ (!isa<SExtInst>(Args[1 ]) && !isa<ZExtInst>(Args[1 ])) ||
215
+ !Args[1 ]->hasOneUse ())
216
+ return false ;
217
+ auto *Extend = cast<CastInst>(Args[1 ]);
218
+
219
+ // Legalize the destination type and ensure it can be used in a widening
220
+ // operation.
221
+ auto DstTyL = TLI->getTypeLegalizationCost (DL, DstTy);
222
+ unsigned DstElTySize = DstTyL.second .getScalarSizeInBits ();
223
+ if (!DstTyL.second .isVector () || DstElTySize != DstTy->getScalarSizeInBits ())
224
+ return false ;
225
+
226
+ // Legalize the source type and ensure it can be used in a widening
227
+ // operation.
228
+ Type *SrcTy = toVectorTy (Extend->getSrcTy ());
229
+ auto SrcTyL = TLI->getTypeLegalizationCost (DL, SrcTy);
230
+ unsigned SrcElTySize = SrcTyL.second .getScalarSizeInBits ();
231
+ if (!SrcTyL.second .isVector () || SrcElTySize != SrcTy->getScalarSizeInBits ())
232
+ return false ;
233
+
234
+ // Get the total number of vector elements in the legalized types.
235
+ unsigned NumDstEls = DstTyL.first * DstTyL.second .getVectorNumElements ();
236
+ unsigned NumSrcEls = SrcTyL.first * SrcTyL.second .getVectorNumElements ();
237
+
238
+ // Return true if the legalized types have the same number of vector elements
239
+ // and the destination element type size is twice that of the source type.
240
+ return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
241
+ }
242
+
179
243
int AArch64TTIImpl::getCastInstrCost (unsigned Opcode, Type *Dst, Type *Src,
180
244
const Instruction *I) {
181
245
int ISD = TLI->InstructionOpcodeToISD (Opcode);
182
246
assert (ISD && " Invalid opcode" );
183
247
248
+ // If the cast is observable, and it is used by a widening instruction (e.g.,
249
+ // uaddl, saddw, etc.), it may be free.
250
+ if (I && I->hasOneUse ()) {
251
+ auto *SingleUser = cast<Instruction>(*I->user_begin ());
252
+ SmallVector<const Value *, 4 > Operands (SingleUser->operand_values ());
253
+ if (isWideningInstruction (Dst, SingleUser->getOpcode (), Operands)) {
254
+ // If the cast is the second operand, it is free. We will generate either
255
+ // a "wide" or "long" version of the widening instruction.
256
+ if (I == SingleUser->getOperand (1 ))
257
+ return 0 ;
258
+ // If the cast is not the second operand, it will be free if it looks the
259
+ // same as the second operand. In this case, we will generate a "long"
260
+ // version of the widening instruction.
261
+ if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand (1 )))
262
+ if (I->getOpcode () == Cast->getOpcode () &&
263
+ cast<CastInst>(I)->getSrcTy () == Cast->getSrcTy ())
264
+ return 0 ;
265
+ }
266
+ }
267
+
184
268
EVT SrcTy = TLI->getValueType (DL, Src);
185
269
EVT DstTy = TLI->getValueType (DL, Dst);
186
270
@@ -379,6 +463,16 @@ int AArch64TTIImpl::getArithmeticInstrCost(
379
463
// Legalize the type.
380
464
std::pair<int , MVT> LT = TLI->getTypeLegalizationCost (DL, Ty);
381
465
466
+ // If the instruction is a widening instruction (e.g., uaddl, saddw, etc.),
467
+ // add in the widening overhead specified by the sub-target. Since the
468
+ // extends feeding widening instructions are performed automatically, they
469
+ // aren't present in the generated code and have a zero cost. By adding a
470
+ // widening overhead here, we attach the total cost of the combined operation
471
+ // to the widening instruction.
472
+ int Cost = 0 ;
473
+ if (isWideningInstruction (Ty, Opcode, Args))
474
+ Cost += ST->getWideningBaseCost ();
475
+
382
476
int ISD = TLI->InstructionOpcodeToISD (Opcode);
383
477
384
478
if (ISD == ISD::SDIV &&
@@ -388,9 +482,9 @@ int AArch64TTIImpl::getArithmeticInstrCost(
388
482
// normally expanded to the sequence ADD + CMP + SELECT + SRA.
389
483
// The OperandValue properties may not be the same as that of previous
390
484
// operation; conservatively assume OP_None.
391
- int Cost = getArithmeticInstrCost (Instruction::Add, Ty, Opd1Info, Opd2Info,
392
- TargetTransformInfo::OP_None,
393
- TargetTransformInfo::OP_None);
485
+ Cost + = getArithmeticInstrCost (Instruction::Add, Ty, Opd1Info, Opd2Info,
486
+ TargetTransformInfo::OP_None,
487
+ TargetTransformInfo::OP_None);
394
488
Cost += getArithmeticInstrCost (Instruction::Sub, Ty, Opd1Info, Opd2Info,
395
489
TargetTransformInfo::OP_None,
396
490
TargetTransformInfo::OP_None);
@@ -405,16 +499,16 @@ int AArch64TTIImpl::getArithmeticInstrCost(
405
499
406
500
switch (ISD) {
407
501
default :
408
- return BaseT::getArithmeticInstrCost (Opcode, Ty, Opd1Info, Opd2Info,
409
- Opd1PropInfo, Opd2PropInfo);
502
+ return Cost + BaseT::getArithmeticInstrCost (Opcode, Ty, Opd1Info, Opd2Info,
503
+ Opd1PropInfo, Opd2PropInfo);
410
504
case ISD::ADD:
411
505
case ISD::MUL:
412
506
case ISD::XOR:
413
507
case ISD::OR:
414
508
case ISD::AND:
415
509
// These nodes are marked as 'custom' for combining purposes only.
416
510
// We know that they are legal. See LowerAdd in ISelLowering.
417
- return 1 * LT.first ;
511
+ return (Cost + 1 ) * LT.first ;
418
512
}
419
513
}
420
514
0 commit comments