diff --git a/compiler/rustc_parse/src/lexer/mod.rs b/compiler/rustc_parse/src/lexer/mod.rs
index 1131f00cb42..034442b798b 100644
--- a/compiler/rustc_parse/src/lexer/mod.rs
+++ b/compiler/rustc_parse/src/lexer/mod.rs
@@ -1,5 +1,6 @@
 use rustc_ast::ast::AttrStyle;
 use rustc_ast::token::{self, CommentKind, Token, TokenKind};
+use rustc_ast::tokenstream::IsJoint;
 use rustc_data_structures::sync::Lrc;
 use rustc_errors::{error_code, Applicability, DiagnosticBuilder, FatalError};
 use rustc_lexer::Base;
@@ -65,42 +66,46 @@ impl<'a> StringReader<'a> {
         self.override_span.unwrap_or_else(|| Span::with_root_ctxt(lo, hi))
     }
 
-    /// Returns the next token, including trivia like whitespace or comments.
-    fn next_token(&mut self) -> Token {
+    /// Returns the next token, and info about preceding trivia (whitespace or comments), if any.
+    fn next_token(&mut self) -> (IsJoint, Token) {
+        let mut is_joint = IsJoint::Joint;
+
+        // Skip `#!` at the start of the file
         let start_src_index = self.src_index(self.pos);
         let text: &str = &self.src[start_src_index..self.end_src_index];
-
-        if text.is_empty() {
-            let span = self.mk_sp(self.pos, self.pos);
-            return Token::new(token::Eof, span);
-        }
-
-        {
-            let is_beginning_of_file = self.pos == self.start_pos;
-            if is_beginning_of_file {
-                if let Some(shebang_len) = rustc_lexer::strip_shebang(text) {
-                    let start = self.pos;
-                    self.pos = self.pos + BytePos::from_usize(shebang_len);
-
-                    let sym = self.symbol_from(start + BytePos::from_usize("#!".len()));
-                    let kind = token::Shebang(sym);
-
-                    let span = self.mk_sp(start, self.pos);
-                    return Token::new(kind, span);
-                }
+        let is_beginning_of_file = self.pos == self.start_pos;
+        if is_beginning_of_file {
+            if let Some(shebang_len) = rustc_lexer::strip_shebang(text) {
+                self.pos = self.pos + BytePos::from_usize(shebang_len);
+                is_joint = IsJoint::NonJoint;
             }
         }
 
-        let token = rustc_lexer::first_token(text);
+        // Skip trivia (whitespace and comment) tokens
+        loop {
+            let start_src_index = self.src_index(self.pos);
+            let text: &str = &self.src[start_src_index..self.end_src_index];
 
-        let start = self.pos;
-        self.pos = self.pos + BytePos::from_usize(token.len);
+            if text.is_empty() {
+                let span = self.mk_sp(self.pos, self.pos);
+                return (is_joint, Token::new(token::Eof, span));
+            }
 
-        debug!("try_next_token: {:?}({:?})", token.kind, self.str_from(start));
+            let token = rustc_lexer::first_token(text);
 
-        let kind = self.cook_lexer_token(token.kind, start);
-        let span = self.mk_sp(start, self.pos);
-        Token::new(kind, span)
+            let start = self.pos;
+            self.pos = self.pos + BytePos::from_usize(token.len);
+
+            debug!("next_token: {:?}({:?})", token.kind, self.str_from(start));
+
+            match self.cook_lexer_token(token.kind, start) {
+                Some(kind) => {
+                    let span = self.mk_sp(start, self.pos);
+                    return (is_joint, Token::new(kind, span));
+                }
+                None => is_joint = IsJoint::NonJoint,
+            }
+        }
     }
 
     /// Report a fatal lexical error with a given span.
@@ -140,19 +145,16 @@ impl<'a> StringReader<'a> {
     /// Turns simple `rustc_lexer::TokenKind` enum into a rich
     /// `librustc_ast::TokenKind`. This turns strings into interned
     /// symbols and runs additional validation.
-    fn cook_lexer_token(&self, token: rustc_lexer::TokenKind, start: BytePos) -> TokenKind {
-        match token {
+    fn cook_lexer_token(&self, token: rustc_lexer::TokenKind, start: BytePos) -> Option<TokenKind> {
+        Some(match token {
             rustc_lexer::TokenKind::LineComment { doc_style } => {
-                match doc_style {
-                    Some(doc_style) => {
-                        // Opening delimiter of the length 3 is not included into the symbol.
-                        let content_start = start + BytePos(3);
-                        let content = self.str_from(content_start);
+                // Skip non-doc comments
+                let doc_style = doc_style?;
 
-                        self.cook_doc_comment(content_start, content, CommentKind::Line, doc_style)
-                    }
-                    None => token::Comment,
-                }
+                // The opening delimiter of length 3 is not included in the symbol.
+                let content_start = start + BytePos(3);
+                let content = self.str_from(content_start);
+                self.cook_doc_comment(content_start, content, CommentKind::Line, doc_style)
             }
             rustc_lexer::TokenKind::BlockComment { doc_style, terminated } => {
                 if !terminated {
@@ -171,20 +173,18 @@ impl<'a> StringReader<'a> {
                         .emit();
                     FatalError.raise();
                 }
-                match doc_style {
-                    Some(doc_style) => {
-                        // Opening delimiter of the length 3 and closing delimiter of the length 2
-                        // are not included into the symbol.
-                        let content_start = start + BytePos(3);
-                        let content_end = self.pos - BytePos(if terminated { 2 } else { 0 });
-                        let content = self.str_from_to(content_start, content_end);
 
-                        self.cook_doc_comment(content_start, content, CommentKind::Block, doc_style)
-                    }
-                    None => token::Comment,
-                }
+                // Skip non-doc comments
+                let doc_style = doc_style?;
+
+                // The opening delimiter of length 3 and the closing delimiter of length 2
+                // are not included in the symbol.
+                let content_start = start + BytePos(3);
+                let content_end = self.pos - BytePos(if terminated { 2 } else { 0 });
+                let content = self.str_from_to(content_start, content_end);
+                self.cook_doc_comment(content_start, content, CommentKind::Block, doc_style)
             }
-            rustc_lexer::TokenKind::Whitespace => token::Whitespace,
+            rustc_lexer::TokenKind::Whitespace => return None,
             rustc_lexer::TokenKind::Ident | rustc_lexer::TokenKind::RawIdent => {
                 let is_raw_ident = token == rustc_lexer::TokenKind::RawIdent;
                 let mut ident_start = start;
@@ -282,12 +282,11 @@ impl<'a> StringReader<'a> {
                 // this should be inside `rustc_lexer`. However, we should first remove compound
                 // tokens like `<<` from `rustc_lexer`, and then add fancier error recovery to it,
                 // as there will be less overall work to do this way.
-                let token = unicode_chars::check_for_substitution(self, start, c, &mut err)
-                    .unwrap_or_else(|| token::Unknown(self.symbol_from(start)));
+                let token = unicode_chars::check_for_substitution(self, start, c, &mut err);
                 err.emit();
-                token
+                token?
             }
-        }
+        })
     }
 
     fn cook_doc_comment(
@@ -450,12 +449,6 @@ impl<'a> StringReader<'a> {
         self.str_from_to(start, self.pos)
     }
 
-    /// Creates a Symbol from a given offset to the current offset.
-    fn symbol_from(&self, start: BytePos) -> Symbol {
-        debug!("taking an ident from {:?} to {:?}", start, self.pos);
-        Symbol::intern(self.str_from(start))
-    }
-
     /// As symbol_from, with an explicit endpoint.
     fn symbol_from_to(&self, start: BytePos, end: BytePos) -> Symbol {
         debug!("taking an ident from {:?} to {:?}", start, end);
diff --git a/compiler/rustc_parse/src/lexer/tokentrees.rs b/compiler/rustc_parse/src/lexer/tokentrees.rs
index c08659ec9f6..6e13bfb9c9d 100644
--- a/compiler/rustc_parse/src/lexer/tokentrees.rs
+++ b/compiler/rustc_parse/src/lexer/tokentrees.rs
@@ -272,19 +272,9 @@ impl<'a> TokenTreesReader<'a> {
     }
 
     fn real_token(&mut self) {
-        self.joint_to_prev = Joint;
-        loop {
-            let token = self.string_reader.next_token();
-            match token.kind {
-                token::Whitespace | token::Comment | token::Shebang(_) | token::Unknown(_) => {
-                    self.joint_to_prev = NonJoint;
-                }
-                _ => {
-                    self.token = token;
-                    return;
-                }
-            }
-        }
+        let (joint_to_prev, token) = self.string_reader.next_token();
+        self.joint_to_prev = joint_to_prev;
+        self.token = token;
     }
 }