Remove getPointerElementType() uses from the LLVMIR dialect translation builders.

The llvmBuilder snippets of llvm.alloca, llvm.getelementptr, llvm.load,
llvm.intr.matrix.column.major.load/store, llvm.intr.masked.load and
llvm.intr.masked.gather now obtain the pointee type from the MLIR-side
LLVMPointerType (cast<LLVMPointerType>().getElementType()) and translate it
with moduleTranslation.convertType(), instead of querying the pointer type
of the already-translated llvm::Value. The pointer operands of the
intrinsic ops are constrained to LLVM_AnyPointer (LLVM_VectorOf<LLVM_AnyPointer>
for the gather) so the operand constraint matches the cast<LLVMPointerType>()
performed in each builder. For GEP, a vector-of-pointers base is unwrapped
with LLVM::getVectorElementType() before the cast.

NOTE(review): every builder assumes a typed (non-opaque) LLVMPointerType;
getElementType() on an opaque pointer would presumably yield a null type that
is then passed to convertType() -- confirm opaque pointers cannot reach these
ops.

Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -296,9 +296,10 @@
                    OptionalAttr<I64Attr>:$alignment);
   let results = (outs Res<LLVM_AnyPointer, "",
                           [MemAlloc<AutomaticAllocationScopeResource>]>:$res);
   string llvmBuilder = [{
+    auto ty = op.getRes().getType().cast<LLVMPointerType>().getElementType();
     auto addrSpace = $_resultType->getPointerAddressSpace();
     auto *inst = builder.CreateAlloca(
-      $_resultType->getPointerElementType(), addrSpace, $arraySize);
+      moduleTranslation.convertType(ty), addrSpace, $arraySize);
   }] # setAlignmentCode # [{
     $res = inst;
   }];
@@ -337,8 +338,12 @@
       else
         indices.push_back(builder.getInt32(structIndex));
     }
+    auto ptrTy = op.getBase().getType();
+    if (LLVM::isCompatibleVectorType(ptrTy))
+      ptrTy = LLVM::getVectorElementType(ptrTy);
+    auto elemTy = ptrTy.cast<LLVMPointerType>().getElementType();
     $res = builder.CreateGEP(
-      $base->getType()->getPointerElementType(), $base, indices);
+      moduleTranslation.convertType(elemTy), $base, indices);
   }];
   let assemblyFormat = [{
     $base `[` custom<GEPIndices>($indices, $structIndices) `]` attr-dict
@@ -361,8 +366,10 @@
                    UnitAttr:$nontemporal);
   let results = (outs LLVM_LoadableType:$res);
   string llvmBuilder = [{
+    auto accessTy =
+        op.getAddr().getType().cast<LLVMPointerType>().getElementType();
     auto *inst = builder.CreateLoad(
-      $addr->getType()->getPointerElementType(), $addr, $volatile_);
+      moduleTranslation.convertType(accessTy), $addr, $volatile_);
   }] # setAlignmentCode
      # setNonTemporalMetadataCode
      # setAccessGroupsMetadataCode
@@ -1588,19 +1595,19 @@
 /// columns - Number of columns in matrix (must be a constant)
 /// stride - Space between columns
 def LLVM_MatrixColumnMajorLoadOp : LLVM_Op<"intr.matrix.column.major.load"> {
-  let arguments = (ins LLVM_Type:$data, LLVM_Type:$stride, I1Attr:$isVolatile,
-                   I32Attr:$rows, I32Attr:$columns);
+  let arguments = (ins LLVM_AnyPointer:$data, LLVM_Type:$stride,
+                   I1Attr:$isVolatile, I32Attr:$rows, I32Attr:$columns);
   let results = (outs LLVM_Type:$res);
   let builders = [LLVM_OneResultOpBuilder];
   string llvmBuilder = [{
     llvm::MatrixBuilder mb(builder);
     const llvm::DataLayout &dl =
       builder.GetInsertBlock()->getModule()->getDataLayout();
-    llvm::Type *ElemTy = $data->getType()->getPointerElementType();
-    llvm::Align align = dl.getABITypeAlign(ElemTy);
+    auto ty = op.getData().getType().cast<LLVMPointerType>().getElementType();
+    llvm::Type *llvmTy = moduleTranslation.convertType(ty);
+    llvm::Align align = dl.getABITypeAlign(llvmTy);
     $res = mb.CreateColumnMajorLoad(
-      ElemTy, $data, align, $stride, $isVolatile, $rows,
-      $columns);
+      llvmTy, $data, align, $stride, $isVolatile, $rows, $columns);
   }];
   let assemblyFormat = "$data `,` `<` `stride` `=` $stride `>` attr-dict"
     "`:` type($res) `from` type($data) `stride` type($stride)";
@@ -1615,15 +1622,17 @@
 /// columns - Number of columns in matrix (must be a constant)
 /// stride - Space between columns
 def LLVM_MatrixColumnMajorStoreOp : LLVM_Op<"intr.matrix.column.major.store"> {
-  let arguments = (ins LLVM_Type:$matrix, LLVM_Type:$data, LLVM_Type:$stride,
-                   I1Attr:$isVolatile, I32Attr:$rows, I32Attr:$columns);
+  let arguments = (ins LLVM_Type:$matrix, LLVM_AnyPointer:$data,
+                   LLVM_Type:$stride, I1Attr:$isVolatile, I32Attr:$rows,
+                   I32Attr:$columns);
   let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder];
   string llvmBuilder = [{
     llvm::MatrixBuilder mb(builder);
     const llvm::DataLayout &dl =
       builder.GetInsertBlock()->getModule()->getDataLayout();
-    llvm::Align align = dl.getABITypeAlign(
-      $data->getType()->getPointerElementType());
+    auto ty = op.getData().getType().cast<LLVMPointerType>().getElementType();
+    llvm::Type *llvmTy = moduleTranslation.convertType(ty);
+    llvm::Align align = dl.getABITypeAlign(llvmTy);
     mb.CreateColumnMajorStore(
       $matrix, $data, align, $stride, $isVolatile,
       $rows, $columns);
@@ -1677,16 +1686,17 @@
 
 /// Create a call to Masked Load intrinsic.
 def LLVM_MaskedLoadOp : LLVM_Op<"intr.masked.load"> {
-  let arguments = (ins LLVM_Type:$data, LLVM_Type:$mask,
+  let arguments = (ins LLVM_AnyPointer:$data, LLVM_Type:$mask,
                    Variadic<LLVM_Type>:$pass_thru, I32Attr:$alignment);
   let results = (outs LLVM_Type:$res);
   let builders = [LLVM_OneResultOpBuilder];
   string llvmBuilder = [{
-    llvm::Type *Ty = $data->getType()->getPointerElementType();
+    auto ty = op.getData().getType().cast<LLVMPointerType>().getElementType();
+    llvm::Type *llvmTy = moduleTranslation.convertType(ty);
     $res = $pass_thru.empty() ? builder.CreateMaskedLoad(
-        Ty, $data, llvm::Align($alignment), $mask) :
+        llvmTy, $data, llvm::Align($alignment), $mask) :
       builder.CreateMaskedLoad(
-        Ty, $data, llvm::Align($alignment), $mask, $pass_thru[0]);
+        llvmTy, $data, llvm::Align($alignment), $mask, $pass_thru[0]);
   }];
   let assemblyFormat =
     "operands attr-dict `:` functional-type(operands, results)";
@@ -1707,15 +1717,16 @@
 
 /// Create a call to Masked Gather intrinsic.
 def LLVM_masked_gather : LLVM_Op<"intr.masked.gather"> {
-  let arguments = (ins LLVM_Type:$ptrs, LLVM_Type:$mask,
+  let arguments = (ins LLVM_VectorOf<LLVM_AnyPointer>:$ptrs, LLVM_Type:$mask,
                    Variadic<LLVM_Type>:$pass_thru, I32Attr:$alignment);
   let results = (outs LLVM_Type:$res);
   let builders = [LLVM_OneResultOpBuilder];
   string llvmBuilder = [{
-    llvm::VectorType *PtrVecTy = cast<llvm::VectorType>($ptrs->getType());
+    auto vecTy = op.getPtrs().getType();
+    auto ptrTy = LLVM::getVectorElementType(vecTy).cast<LLVMPointerType>();
     llvm::Type *Ty = llvm::VectorType::get(
-      PtrVecTy->getElementType()->getPointerElementType(),
-      PtrVecTy->getElementCount());
+      moduleTranslation.convertType(ptrTy.getElementType()),
+      LLVM::getVectorNumElements(vecTy));
     $res = $pass_thru.empty() ? builder.CreateMaskedGather(
         Ty, $ptrs, llvm::Align($alignment), $mask) :
       builder.CreateMaskedGather(