diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp --- a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -104,7 +104,14 @@ BuiltinTypeDeclBuilder & addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) { - return addMemberVariable("h", Record->getASTContext().VoidPtrTy, Access); + QualType Ty = Record->getASTContext().VoidPtrTy; + if (Template) { + if (const auto *TTD = dyn_cast( + Template->getTemplateParameters()->getParam(0))) + Ty = Record->getASTContext().getPointerType( + QualType(TTD->getTypeForDecl(), 0)); + } + return addMemberVariable("h", Ty, Access); } BuiltinTypeDeclBuilder & @@ -158,15 +165,25 @@ lookupBuiltinFunction(AST, S, "__builtin_hlsl_create_handle"); Expr *RCExpr = emitResourceClassExpr(AST, RC); - CallExpr *Call = - CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue, - SourceLocation(), FPOptionsOverride()); + Expr *Call = CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue, + SourceLocation(), FPOptionsOverride()); CXXThisExpr *This = new (AST) CXXThisExpr(SourceLocation(), Constructor->getThisType(), true); - MemberExpr *Handle = MemberExpr::CreateImplicit( - AST, This, true, Fields["h"], Fields["h"]->getType(), VK_LValue, - OK_Ordinary); + Expr *Handle = MemberExpr::CreateImplicit(AST, This, true, Fields["h"], + Fields["h"]->getType(), VK_LValue, + OK_Ordinary); + + // If the handle isn't a void pointer, cast the builtin result to the + // correct type. 
+ if (Handle->getType().getCanonicalType() != AST.VoidPtrTy) { + Call = CXXStaticCastExpr::Create( + AST, Handle->getType(), VK_PRValue, CK_Dependent, Call, nullptr, + AST.getTrivialTypeSourceInfo(Handle->getType(), SourceLocation()), + FPOptionsOverride(), SourceLocation(), SourceLocation(), + SourceRange()); + } + BinaryOperator *Assign = BinaryOperator::Create( AST, Handle, Call, BO_Assign, Handle->getType(), VK_LValue, OK_Ordinary, SourceLocation(), FPOptionsOverride()); @@ -179,6 +196,85 @@ return *this; } + BuiltinTypeDeclBuilder &addArraySubscriptOperators() { + addArraySubscriptOperator(true); + addArraySubscriptOperator(false); + return *this; + } + + BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) { + assert(Fields.count("h") > 0 && + "Subscript operator must be added after the handle."); + + FieldDecl *Handle = Fields["h"]; + ASTContext &AST = Record->getASTContext(); + + assert(Handle->getType().getCanonicalType() != AST.VoidPtrTy && + "Not yet supported for void pointer handles."); + + QualType ElemTy = + QualType(Handle->getType()->getPointeeOrArrayElementType(), 0); + QualType ReturnTy = ElemTy; + + FunctionProtoType::ExtProtoInfo ExtInfo; + + // Subscript operators return references to elements, const makes the + // reference and method const so that the underlying data is not mutable. 
+ ReturnTy = AST.getLValueReferenceType(ReturnTy); + if (IsConst) { + ExtInfo.TypeQuals.addConst(); + ReturnTy.addConst(); + } + + QualType MethodTy = + AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo); + auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation()); + auto *MethodDecl = CXXMethodDecl::Create( + AST, Record, SourceLocation(), + DeclarationNameInfo( + AST.DeclarationNames.getCXXOperatorName(OO_Subscript), + SourceLocation()), + MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified, + SourceLocation()); + + IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier); + auto *IdxParam = ParmVarDecl::Create( + AST, MethodDecl->getDeclContext(), SourceLocation(), SourceLocation(), + &II, AST.UnsignedIntTy, + AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()), + SC_None, nullptr); + MethodDecl->setParams({IdxParam}); + + // Also add the parameter to the function prototype. + auto FnProtoLoc = TSInfo->getTypeLoc().getAs(); + FnProtoLoc.setParam(0, IdxParam); + + auto *This = new (AST) + CXXThisExpr(SourceLocation(), MethodDecl->getThisType(), true); + auto *HandleAccess = MemberExpr::CreateImplicit( + AST, This, true, Handle, Handle->getType(), VK_LValue, OK_Ordinary); + + auto *IndexExpr = DeclRefExpr::Create( + AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false, + DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()), + AST.UnsignedIntTy, VK_PRValue); + + auto *Array = + new (AST) ArraySubscriptExpr(HandleAccess, IndexExpr, ElemTy, VK_LValue, + OK_Ordinary, SourceLocation()); + + auto *Return = ReturnStmt::Create(AST, SourceLocation(), Array, nullptr); + + MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(), + SourceLocation(), + SourceLocation())); + MethodDecl->setLexicalDeclContext(Record); + MethodDecl->setAccess(AccessSpecifier::AS_public); + Record->addDecl(MethodDecl); + + return *this; + } + BuiltinTypeDeclBuilder &startDefinition() { 
Record->startDefinition(); return *this; @@ -368,6 +464,7 @@ BuiltinTypeDeclBuilder(Record) .addHandleMember() .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV) + .addArraySubscriptOperators() .annotateResourceClass(HLSLResourceAttr::UAV) .completeDefinition(); } diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp --- a/clang/lib/Sema/SemaType.cpp +++ b/clang/lib/Sema/SemaType.cpp @@ -2174,7 +2174,7 @@ return QualType(); } - if (getLangOpts().HLSL) { + if (getLangOpts().HLSL && Loc.isValid()) { Diag(Loc, diag::err_hlsl_pointers_unsupported) << 0; return QualType(); } @@ -2244,7 +2244,7 @@ return QualType(); } - if (getLangOpts().HLSL) { + if (getLangOpts().HLSL && Loc.isValid()) { Diag(Loc, diag::err_hlsl_pointers_unsupported) << 1; return QualType(); } @@ -3008,7 +3008,7 @@ return QualType(); } - if (getLangOpts().HLSL) { + if (getLangOpts().HLSL && Loc.isValid()) { Diag(Loc, diag::err_hlsl_pointers_unsupported) << 0; return QualType(); } diff --git a/clang/test/AST/HLSL/RWBuffer-AST.hlsl b/clang/test/AST/HLSL/RWBuffer-AST.hlsl --- a/clang/test/AST/HLSL/RWBuffer-AST.hlsl +++ b/clang/test/AST/HLSL/RWBuffer-AST.hlsl @@ -39,11 +39,30 @@ // CHECK: FinalAttr 0x{{[0-9A-Fa-f]+}} <> Implicit final // CHECK-NEXT: HLSLResourceAttr 0x{{[0-9A-Fa-f]+}} <> Implicit UAV -// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <> implicit h 'void *' +// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <> implicit h 'element_type *' + +// CHECK: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <> operator[] 'element_type &const (unsigned int) const' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <> Idx 'unsigned int' +// CHECK-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <> +// CHECK-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <> +// CHECK-NEXT: ArraySubscriptExpr 0x{{[0-9A-Fa-f]+}} <> 'element_type' lvalue +// CHECK-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <> 'element_type *' lvalue ->h 0x{{[0-9A-Fa-f]+}} +// CHECK-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <> 'const RWBuffer *' implicit this +// CHECK-NEXT: 
DeclRefExpr 0x{{[0-9A-Fa-f]+}} <> 'unsigned int' ParmVar 0x{{[0-9A-Fa-f]+}} 'Idx' 'unsigned int' + +// CHECK-NEXT: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <> operator[] 'element_type &(unsigned int)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <> Idx 'unsigned int' +// CHECK-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <> +// CHECK-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <> +// CHECK-NEXT: ArraySubscriptExpr 0x{{[0-9A-Fa-f]+}} <> 'element_type' lvalue +// CHECK-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <> 'element_type *' lvalue ->h 0x{{[0-9A-Fa-f]+}} +// CHECK-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <> 'RWBuffer *' implicit this +// CHECK-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <> 'unsigned int' ParmVar 0x{{[0-9A-Fa-f]+}} 'Idx' 'unsigned int' + // CHECK: ClassTemplateSpecializationDecl 0x{{[0-9A-Fa-f]+}} <> class RWBuffer definition // CHECK: TemplateArgument type 'float' // CHECK-NEXT: BuiltinType 0x{{[0-9A-Fa-f]+}} 'float' // CHECK-NEXT: FinalAttr 0x{{[0-9A-Fa-f]+}} <> Implicit final // CHECK-NEXT: HLSLResourceAttr 0x{{[0-9A-Fa-f]+}} <> Implicit UAV -// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <> implicit referenced h 'void *' +// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <> implicit referenced h 'float *' diff --git a/clang/test/CodeGenHLSL/buffer-array-operator.hlsl b/clang/test/CodeGenHLSL/buffer-array-operator.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/CodeGenHLSL/buffer-array-operator.hlsl @@ -0,0 +1,30 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s + +const RWBuffer In; +RWBuffer Out; + +void fn(int Idx) { + Out[Idx] = In[Idx]; +} + +// This test is intended to verify reasonable code generation of the subscript +// operator. In this test case we should be generating both the const and +// non-const operators so we verify both cases. + +// The const overload (QBA in the mangled name) comes first. 
+// CHECK: ptr @"??A?$RWBuffer@M@hlsl@@QBAAAMI@Z" +// CHECK: %this1 = load ptr, ptr %this.addr, align 4 +// CHECK-NEXT: %h = getelementptr inbounds %"class.hlsl::RWBuffer", ptr %this1, i32 0, i32 0 +// CHECK-NEXT: %0 = load ptr, ptr %h, align 4 +// CHECK-NEXT: %1 = load i32, ptr %Idx.addr, align 4 +// CHECK-NEXT: %arrayidx = getelementptr inbounds float, ptr %0, i32 %1 +// CHECK-NEXT: ret ptr %arrayidx + +// The non-const overload (QAA in the mangled name) comes next; it also returns a pointer to the element. +// CHECK: ptr @"??A?$RWBuffer@M@hlsl@@QAAAAMI@Z" +// CHECK: %this1 = load ptr, ptr %this.addr, align 4 +// CHECK-NEXT: %h = getelementptr inbounds %"class.hlsl::RWBuffer", ptr %this1, i32 0, i32 0 +// CHECK-NEXT: %0 = load ptr, ptr %h, align 4 +// CHECK-NEXT: %1 = load i32, ptr %Idx.addr, align 4 +// CHECK-NEXT: %arrayidx = getelementptr inbounds float, ptr %0, i32 %1 +// CHECK-NEXT: ret ptr %arrayidx